Phi4 Multimodal

Overview

Phi4 Multimodal is a lightweight open multimodal foundation model that builds on the language, vision, and speech research and datasets used for the Phi-3.5 and 4.0 models. The model processes text, image, and audio inputs, generates text outputs, and supports a context length of 128K tokens. It underwent an enhancement process combining supervised fine-tuning, direct preference optimization (DPO), and reinforcement learning from human feedback (RLHF) to support precise instruction adherence and safety. The languages supported by each modality are:

  • Text: Arabic, Chinese, Czech, Danish, Dutch, English, Finnish, French, German, Hebrew, Hungarian, Italian, Japanese, Korean, Norwegian, Polish, Portuguese, Russian, Spanish, Swedish, Thai, Turkish, Ukrainian
  • Vision: English
  • Audio: English, Chinese, German, French, Italian, Japanese, Spanish, Portuguese

This model was contributed by Cyril Vallez. The most recent code can be found here.

Usage tips

The Phi-4-multimodal-instruct checkpoint is available on the Hugging Face Hub as microsoft/Phi-4-multimodal-instruct.

The following example shows how to run inference for different input modalities: a text prompt paired with an image, and a text prompt paired with an audio clip.

import io
from urllib.request import urlopen

import requests
import soundfile as sf
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor


# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"
device = "cuda:0"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16)

# Optional: load the adapters (note that without them, the base model will very likely not work well)
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})

# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# Part 1: Image Processing
model.set_adapter("vision") # if loaded, activate the vision adapter
print("\n--- IMAGE PROCESSING ---")
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')

# Download and open image
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)

# Generate response
generate_ids = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(f'>>> Response\n{response}')

# Part 2: Audio Processing
model.set_adapter("speech") # if loaded, activate the speech adapter
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')

# Download and open the audio file
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))

# Process with the model
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)

generate_ids = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(f'>>> Response\n{response}')
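Every prompt in the example above follows the same pattern: <|user|>, any numbered placeholders, the text, <|end|>, then <|assistant|>. The hypothetical helper below assembles that string for text-only, image, or audio prompts. Numbering multiple placeholders consecutively (<|image_1|>, <|image_2|>, ...) and passing the media to the processor in the same order is an assumption based on the processor's default placeholder patterns, not something this page spells out.

```python
# Special tokens taken from the usage example above.
USER, ASSISTANT, END = "<|user|>", "<|assistant|>", "<|end|>"


def build_prompt(text: str, num_images: int = 0, num_audios: int = 0) -> str:
    """Hypothetical helper: build a single-turn prompt with numbered media placeholders."""
    images = "".join(f"<|image_{i}|>" for i in range(1, num_images + 1))
    audios = "".join(f"<|audio_{i}|>" for i in range(1, num_audios + 1))
    return f"{USER}{images}{audios}{text}{END}{ASSISTANT}"


# Reproduces the image prompt from the example above.
print(build_prompt("What is shown in this image?", num_images=1))
# A text-only prompt simply has no placeholders.
print(build_prompt("Summarize the Phi-4 model family in one sentence."))
```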

Phi4MultimodalFeatureExtractor

class transformers.Phi4MultimodalFeatureExtractor

(feature_size: int = 80, sampling_rate: int = 16000, hop_length: int = 160, n_fft: int = 512, win_length: int = 400, preemphasis: float = 0.97, padding_value: float = 0.0, audio_compression_rate: int = 8, audio_downsample_rate: int = 1, audio_feat_stride: int = 1, mel_min_frequency: float = 0, mel_max_frequency: float = 7690, **kwargs)
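The default framing parameters translate into familiar speech-processing quantities: a 400-sample window at 16 kHz spans 25 ms, and a 160-sample hop advances 10 ms, giving roughly 100 feature frames per second of audio. The sketch below computes these numbers from the defaults; the no-padding frame-count formula is an assumption and may differ from the extractor's exact behavior.

```python
# Defaults copied from the Phi4MultimodalFeatureExtractor signature above.
SAMPLING_RATE = 16000
HOP_LENGTH = 160
WIN_LENGTH = 400
N_FFT = 512
MEL_MAX_FREQUENCY = 7690


def num_frames(num_samples: int) -> int:
    """Frames produced by a WIN_LENGTH window sliding by HOP_LENGTH (no padding assumed)."""
    if num_samples < WIN_LENGTH:
        return 0
    return (num_samples - WIN_LENGTH) // HOP_LENGTH + 1


hop_ms = 1000 * HOP_LENGTH / SAMPLING_RATE  # 10.0 ms between consecutive frames
win_ms = 1000 * WIN_LENGTH / SAMPLING_RATE  # 25.0 ms analysis window
freq_bins = N_FFT // 2 + 1                  # 257 one-sided STFT bins before mel pooling
nyquist = SAMPLING_RATE / 2                 # 8000.0 Hz; mel_max_frequency stays below it

print(num_frames(SAMPLING_RATE))  # frames for one second of audio: 98
```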

Phi4MultimodalImageProcessorFast

class transformers.Phi4MultimodalImageProcessorFast

(**kwargs: typing_extensions.Unpack[transformers.models.phi4_multimodal.image_processing_phi4_multimodal_fast.Phi4MultimodalFastImageProcessorKwargs])

Constructs a Phi4Multimodal image processor.

pad_to_max_num_crops

(images, max_crops=5)

Pads a stack of image crops to max_crops. images: a tensor of shape (B, 3, H, W), with B <= max_crops.
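A NumPy sketch of padding a (B, 3, H, W) stack of crops up to max_crops along the first axis. Zero padding and the helper name pad_crops are assumptions for illustration; the fast image processor works on torch tensors, and its actual padding values are not documented here.

```python
import numpy as np


def pad_crops(images: np.ndarray, max_crops: int = 5) -> np.ndarray:
    """Hypothetical helper: zero-pad a (B, 3, H, W) stack of crops to (max_crops, 3, H, W)."""
    num_crops = images.shape[0]
    if num_crops > max_crops:
        raise ValueError(f"got {num_crops} crops, expected at most {max_crops}")
    # Empty crops are appended after the real ones along axis 0.
    padding = np.zeros((max_crops - num_crops, *images.shape[1:]), dtype=images.dtype)
    return np.concatenate([images, padding], axis=0)


padded = pad_crops(np.ones((2, 3, 4, 4), dtype=np.float32))
print(padded.shape)  # (5, 3, 4, 4)
```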

preprocess

(images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']], image_mean: typing.Union[float, typing.List[float], NoneType] = None, image_std: typing.Union[float, typing.List[float], NoneType] = None, return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None)

Parameters

  • images (ImageInput) — Image to preprocess. Expects a single image or a batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set do_rescale=False.
  • image_mean (float or List[float], optional, defaults to self.image_mean) — Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
  • image_std (float or List[float], optional, defaults to self.image_std) — Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
  • return_tensors (str or TensorType, optional) — The type of tensors to return. Can be one of:
    • Unset: Return a list of np.ndarray.
    • TensorType.TENSORFLOW or 'tf': Return a batch of type tf.Tensor.
    • TensorType.PYTORCH or 'pt': Return a batch of type torch.Tensor.
    • TensorType.NUMPY or 'np': Return a batch of type np.ndarray.
    • TensorType.JAX or 'jax': Return a batch of type jax.numpy.ndarray.
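Rescaling and normalization follow the usual image-processor recipe: pixel values in [0, 255] are multiplied by a rescale factor (assumed here to be 1/255) and then normalized per channel with image_mean and image_std. The pure-Python sketch below uses hypothetical mean and std values of 0.5 and also shows why inputs already in [0, 1] need do_rescale=False: rescaling them a second time would shrink them toward zero.

```python
def rescale_and_normalize(channels, mean, std, do_rescale=True, rescale_factor=1 / 255):
    """Normalize each channel as (x * rescale_factor - mean) / std.

    channels: one flat list of pixel values per channel (a simplification of a real image array).
    """
    normalized = []
    for values, m, s in zip(channels, mean, std):
        scaled = [v * rescale_factor for v in values] if do_rescale else list(values)
        normalized.append([(v - m) / s for v in scaled])
    return normalized


# Hypothetical mean/std of 0.5; the processor's own image_mean/image_std apply in practice.
print(rescale_and_normalize([[0, 255]], mean=[0.5], std=[0.5]))
print(rescale_and_normalize([[0.0, 1.0]], mean=[0.5], std=[0.5], do_rescale=False))  # [[-1.0, 1.0]]
```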

Phi4MultimodalProcessor

class transformers.Phi4MultimodalProcessor

(image_processor, audio_processor, tokenizer, fake_image_token_pattern: str = '<\\|image_\\d+\\|>', fake_audio_token_pattern: str = '<\\|audio_\\d+\\|>', **kwargs)

Parameters

  • image_processor (Phi4MultimodalImageProcessorFast) — The image processor to use for images.
  • audio_processor (Phi4MultimodalFeatureExtractor) — The audio processor to use for audio inputs.
  • tokenizer (GPT2TokenizerFast) — The tokenizer to use for text.
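  • fake_image_token_pattern (str, optional, defaults to r"<\|image_\d+\|>") — Regular expression matching the image placeholder tokens (such as <|image_1|>) in the text prompt.
  • fake_audio_token_pattern (str, optional, defaults to r"<\|audio_\d+\|>") — Regular expression matching the audio placeholder tokens (such as <|audio_1|>) in the text prompt.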

Constructs a Phi4Multimodal processor which wraps an image processor, an audio processor, and a GPT-2 tokenizer into a single processor.

Phi4MultimodalProcessor offers all the functionalities of Phi4MultimodalImageProcessorFast, Phi4MultimodalFeatureExtractor, and GPT2Tokenizer. See __call__() and decode() for more information.
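The default placeholder patterns are plain regular expressions, so you can check which placeholders a prompt contains before calling the processor. The sketch below uses only Python's re module; expecting one image placeholder per image passed to the processor is an assumption about intended usage, not documented behavior.

```python
import re

# Default values of fake_image_token_pattern and fake_audio_token_pattern.
IMAGE_PATTERN = r"<\|image_\d+\|>"
AUDIO_PATTERN = r"<\|audio_\d+\|>"

prompt = "<|user|><|image_1|><|image_2|>Compare these two images.<|end|><|assistant|>"

image_placeholders = re.findall(IMAGE_PATTERN, prompt)
audio_placeholders = re.findall(AUDIO_PATTERN, prompt)
print(image_placeholders)  # ['<|image_1|>', '<|image_2|>']
print(audio_placeholders)  # []
```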

batch_decode

(*args, **kwargs)

This method forwards all its arguments to GPT2Tokenizer’s batch_decode(). Please refer to the docstring of this method for more information.

decode

(*args, **kwargs)

This method forwards all its arguments to GPT2Tokenizer’s decode(). Please refer to the docstring of this method for more information.

Phi4MultimodalAudioConfig

class transformers.Phi4MultimodalAudioConfig

(hidden_size: int = 1024, intermediate_size: int = 1536, num_blocks: int = 24, num_attention_heads: int = 16, activation: str = 'swish', chunk_size: int = -1, left_chunk: int = 18, dropout_rate: float = 0.0, ext_pw_out_channel: int = 1024, depthwise_seperable_out_channel: int = 1024, depthwise_multiplier: int = 1, kernel_size: int = 3, conv_activation: str = 'swish', input_size: int = 80, conv_glu_type: str = 'swish', time_reduction: int = 8, bias_max_distance: int = 1000, bias_symmetric: bool = False, nemo_activation: str = 'relu', nemo_conv_channels: int = 1024, downsample_rate: int = 1, initializer_range: float = 0.02, audio_token_id: int = 200011, feature_layer: int = -2, **kwargs)

Parameters

  • hidden_size (int, optional, defaults to 1024) — Dimensionality of the encoder layers.
  • intermediate_size (int, optional, defaults to 1536) — Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
  • num_blocks (int, optional, defaults to 24) — Number of hidden layers in the Transformer encoder.
  • num_attention_heads (int, optional, defaults to 16) — Number of attention heads for each attention layer in the Transformer encoder.
  • activation (str, optional, defaults to "swish") — The non-linear activation function in the MLPs.
  • chunk_size (int, optional, defaults to -1) — The chunk size to create the masks.
  • left_chunk (int, optional, defaults to 18) — The left chunk to create the masks.
  • dropout_rate (float, optional, defaults to 0.0) — The dropout ratio.
  • ext_pw_out_channel (int, optional, defaults to 1024) — Number of out channels in the point-wise conv modules.
  • depthwise_seperable_out_channel (int, optional, defaults to 1024) — Number of out channels in the depth-wise separable conv modules.
  • depthwise_multiplier (int, optional, defaults to 1) — Input size multiplier for the depth-wise separable conv modules.
  • kernel_size (int, optional, defaults to 3) — Kernel size for the depth-wise separable conv modules.
  • conv_activation (str, optional, defaults to "swish") — The non-linear activation function in the conv modules.
  • input_size (int, optional, defaults to 80) — Input size for the audio model.
  • conv_glu_type (str, optional, defaults to "swish") — The non-linear activation function in the point-wise conv modules.
  • time_reduction (int, optional, defaults to 8) — Time reduction (subsampling factor).
  • bias_max_distance (int, optional, defaults to 1000) — Max distance for the relative attention bias module.
  • bias_symmetric (bool, optional, defaults to False) — Whether the relative attention bias should be symmetric or not.
  • nemo_activation (str, optional, defaults to "relu") — The non-linear activation function in the nemo conv modules.
  • nemo_conv_channels (int, optional, defaults to 1024) — Number of channels in the nemo conv modules.
  • downsample_rate (int, optional, defaults to 1) — Downsample rate for the audio feature extractor.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • audio_token_id (int, optional, defaults to 200011) — The audio token id.
  • feature_layer (int, optional, defaults to -2) — The index of the layer of the encoder from which to extract audio features.

This is the configuration class to store the configuration of a Phi4MultimodalAudioModel. It is used to instantiate a Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the audio encoder of the microsoft/Phi-4-multimodal-instruct architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import Phi4MultimodalAudioConfig

>>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration
>>> configuration = Phi4MultimodalAudioConfig()
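A few derived quantities follow directly from the audio defaults. The sketch computes the per-head dimension and an approximate encoder sequence length after subsampling; treating time_reduction as a ceil division of the number of input feature frames is an assumption, since the exact padding rules of the subsampling layers are not described here.

```python
import math

# Defaults copied from the Phi4MultimodalAudioConfig signature above.
hidden_size, num_attention_heads = 1024, 16
time_reduction = 8

head_dim = hidden_size // num_attention_heads  # 64 dims per attention head


def encoder_frames(num_feature_frames: int) -> int:
    """Approximate encoder length, assuming ceil-division subsampling by time_reduction."""
    return math.ceil(num_feature_frames / time_reduction)


# About one second of 10 ms feature frames.
print(encoder_frames(100))  # 13
```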

Phi4MultimodalVisionConfig

class transformers.Phi4MultimodalVisionConfig

(hidden_size=1152, intermediate_size=4304, num_hidden_layers=27, num_attention_heads=16, num_channels=3, image_size=448, patch_size=14, hidden_act='gelu_pytorch_tanh', layer_norm_eps=1e-06, attention_dropout=0.0, crop_size: int = 448, image_token_id: int = 200010, feature_layer: int = -2, **kwargs)

Parameters

  • hidden_size (int, optional, defaults to 1152) — Dimensionality of the encoder layers and the pooler layer.
  • intermediate_size (int, optional, defaults to 4304) — Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer encoder.
  • num_hidden_layers (int, optional, defaults to 27) — Number of hidden layers in the Transformer encoder.
  • num_attention_heads (int, optional, defaults to 16) — Number of attention heads for each attention layer in the Transformer encoder.
  • num_channels (int, optional, defaults to 3) — Number of channels in the input images.
  • image_size (int, optional, defaults to 448) — The size (resolution) of each image.
  • patch_size (int, optional, defaults to 14) — The size (resolution) of each patch.
  • hidden_act (str or function, optional, defaults to "gelu_pytorch_tanh") — The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu", "gelu_new", and "quick_gelu" are supported.
  • layer_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the layer normalization layers.
  • attention_dropout (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
  • crop_size (int, optional, defaults to 448) — Crop size for the input images.
  • image_token_id (int, optional, defaults to 200010) — The image token id.
  • feature_layer (int, optional, defaults to -2) — The index of the layer of the encoder from which to extract image features.

This is the configuration class to store the configuration of a Phi4MultimodalVisionModel. It is used to instantiate a Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the vision encoder of the microsoft/Phi-4-multimodal-instruct architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import Phi4MultimodalVisionConfig

>>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration
>>> configuration = Phi4MultimodalVisionConfig()
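The vision defaults imply the following patch arithmetic for a single 448×448 crop. The sketch also shows what feature_layer=-2 would select if it indexes an output_hidden_states-style tuple (embedding output plus one entry per layer); that indexing convention is an assumption. The patch count is the encoder's sequence length, not necessarily the number of image tokens the language model finally receives.

```python
# Defaults copied from the Phi4MultimodalVisionConfig signature above.
image_size, patch_size = 448, 14
hidden_size, num_attention_heads = 1152, 16
num_hidden_layers, feature_layer = 27, -2

patches_per_side = image_size // patch_size    # 448 / 14 = 32
num_patches = patches_per_side ** 2            # 1024 patch embeddings per crop
head_dim = hidden_size // num_attention_heads  # 1152 / 16 = 72 dims per head

# Stand-in for a hidden-states tuple: embedding output followed by one entry per layer.
hidden_state_names = ["embeddings"] + [f"layer_{i}" for i in range(1, num_hidden_layers + 1)]
selected = hidden_state_names[feature_layer]   # the penultimate layer

print(num_patches, head_dim, selected)  # 1024 72 layer_26
```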

Phi4MultimodalConfig

class transformers.Phi4MultimodalConfig

(vocab_size=200064, hidden_size=3072, intermediate_size=8192, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, partial_rotary_factor=1, bos_token_id=199999, eos_token_id=[199999, 200020], pad_token_id=199999, original_max_position_embeddings=4096, sliding_window=None, vision_config=None, audio_config=None, **kwargs)

Parameters

  • vocab_size (int, optional, defaults to 200064) — Vocabulary size of the model. Defines the number of different tokens that can be represented by the input_ids passed when calling Phi4MultimodalModel.
  • hidden_size (int, optional, defaults to 3072) — Dimension of the hidden representations.
  • intermediate_size (int, optional, defaults to 8192) — Dimension of the MLP representations.
  • num_hidden_layers (int, optional, defaults to 32) — Number of hidden layers in the Transformer decoder.
  • num_attention_heads (int, optional, defaults to 32) — Number of attention heads for each attention layer in the Transformer decoder.
  • num_key_value_heads (int, optional, defaults to 8) — The number of key/value heads used to implement Grouped Query Attention (GQA). If num_key_value_heads=num_attention_heads, the model uses Multi Head Attention (MHA); if num_key_value_heads=1, it uses Multi Query Attention (MQA); otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by mean-pooling all the original heads within that group. For more details, check out this paper. If set to None, it defaults to num_attention_heads.
  • resid_pdrop (float, optional, defaults to 0.0) — Dropout probability for mlp outputs.
  • embd_pdrop (float, optional, defaults to 0.0) — The dropout ratio for the embeddings.
  • attention_dropout (float, optional, defaults to 0.0) — The dropout ratio after computing the attention scores.
  • hidden_act (str or function, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.
  • max_position_embeddings (int, optional, defaults to 131072) — The maximum sequence length that this model might ever be used with.
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • rms_norm_eps (float, optional, defaults to 1e-05) — The epsilon value used for the RMSNorm.
  • use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
  • tie_word_embeddings (bool, optional, defaults to False) — Whether to tie weight embeddings.
  • rope_theta (float, optional, defaults to 10000.0) — The base period of the RoPE embeddings.
  • rope_scaling (dict, optional) — The scaling strategy for the RoPE embeddings. If None, no scaling is applied. If a dictionary, it must contain the following keys: type, short_factor and long_factor. The type must be longrope and the short_factor and long_factor must be lists of numbers with the same length as the hidden size divided by the number of attention heads divided by 2.
  • partial_rotary_factor (float, optional, defaults to 1.0) — Fraction of the query and key dimensions to which rotary embeddings are applied. Must be between 0.0 and 1.0.
  • bos_token_id (int, optional, defaults to 199999) — The id of the “beginning-of-sequence” token.
  • eos_token_id (int or list[int], optional, defaults to [199999, 200020]) — The id of the “end-of-sequence” token.
  • pad_token_id (int, optional, defaults to 199999) — The id of the padding token.
  • original_max_position_embeddings (int, optional, defaults to 4096) — The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling.
  • sliding_window (int, optional) — Sliding window attention window size. If None, no sliding window is applied.
  • vision_config (Phi4MultimodalVisionConfig or dict, optional) — The vision config for the underlying image embedding model. If not provided, will default to the configuration used to instantiate a model similar in architecture to microsoft/Phi-4-multimodal-instruct.
  • audio_config (Phi4MultimodalAudioConfig or dict, optional) — The audio config for the underlying audio embedding model. If not provided, will default to the configuration used to instantiate a model similar in architecture to microsoft/Phi-4-multimodal-instruct.
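The head arithmetic implied by these defaults, including the required length of the short_factor and long_factor lists described under rope_scaling, and the mean-pooling conversion described under num_key_value_heads, can be sketched in NumPy. The weight layout (heads stacked along the output dimension, matching torch.nn.Linear's (out_features, in_features) weight) and the helper name mha_to_gqa are assumptions for illustration, not the library's conversion code.

```python
import numpy as np

# Defaults copied from the Phi4MultimodalConfig signature above.
hidden_size = 3072
num_attention_heads = 32
num_key_value_heads = 8
max_position_embeddings = 131072
original_max_position_embeddings = 4096

head_dim = hidden_size // num_attention_heads                     # 96 dims per head
queries_per_kv_head = num_attention_heads // num_key_value_heads  # 4 query heads share each K/V head
rope_factor_len = head_dim // 2                                   # length of short_factor / long_factor
context_extension = max_position_embeddings // original_max_position_embeddings  # 32x


def mha_to_gqa(kv_weight: np.ndarray, num_heads: int, num_kv_heads: int) -> np.ndarray:
    """Hypothetical helper: mean-pool per-head rows of a key (or value) projection into groups.

    Assumes kv_weight has shape (num_heads * head_dim, in_features), with consecutive
    heads stacked along axis 0 and assigned to groups in order.
    """
    per_head_dim = kv_weight.shape[0] // num_heads
    heads = kv_weight.reshape(num_heads, per_head_dim, -1)
    groups = heads.reshape(num_kv_heads, num_heads // num_kv_heads, per_head_dim, -1)
    return groups.mean(axis=1).reshape(num_kv_heads * per_head_dim, -1)


# Toy example: 4 heads of size 1 pooled into 2 groups gives rows 1.0 and 5.0.
print(mha_to_gqa(np.array([[0.0], [2.0], [4.0], [6.0]]), num_heads=4, num_kv_heads=2))
```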

This is the configuration class to store the configuration of a Phi4MultimodalModel. It is used to instantiate a Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the microsoft/Phi-4-multimodal-instruct architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig

>>> # Initializing a Phi4Multimodal style configuration
>>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct")

>>> # Initializing a model from the configuration
>>> model = Phi4MultimodalModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

Phi4MultimodalAudioModel

class transformers.Phi4MultimodalAudioModel

( config: Phi4MultimodalAudioConfig )

forward_embeddings

(hidden_states, masks)

Forwards the inputs through the top embedding layers.

Phi4MultimodalVisionModel

class transformers.Phi4MultimodalVisionModel

( config: Phi4MultimodalVisionConfig )

Phi4MultimodalModel

class transformers.Phi4MultimodalModel

( config: Phi4MultimodalConfig )

Parameters

  • config (Phi4MultimodalConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads, etc.).

This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a Phi4MultimodalMMDecoderLayer.

forward

(input_ids: LongTensor = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[typing.List[torch.FloatTensor]] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, image_pixel_values: typing.Optional[torch.FloatTensor] = None, image_sizes: typing.Optional[torch.LongTensor] = None, image_attention_mask=None, audio_input_features: typing.Optional[torch.FloatTensor] = None, audio_embed_sizes=None, audio_attention_mask=None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, **kwargs)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.call() for details.

    What are input IDs?

  • attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices in input_ids. Mask values selected in [0, 1]:
    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].

    What are position IDs?

  • past_key_values (Cache, optional) — Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True. See the KV cache guide.

    If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • image_pixel_values (torch.FloatTensor, optional) — If the input contains images, these correspond to the pixel values after transformations (as returned by the processor).
  • image_sizes (torch.LongTensor, optional) — If the input contains images, these correspond to the size of each image.
  • image_attention_mask (torch.LongTensor, optional) — Attention mask for the images.
  • audio_input_features (torch.FloatTensor, optional) — If the input contains audio samples, these correspond to the values after transformation (as returned by the processor).
  • audio_embed_sizes (torch.Tensor, optional) — Size of the audio inputs.
  • audio_attention_mask (torch.Tensor, optional) — Attention mask for the audio inputs.
  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Unlike position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
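To illustrate how position_ids and cache_position differ for a left-padded sequence, the sketch below mirrors a common Hugging Face convention: position_ids are the cumulative sum of the attention mask minus one, with padded slots filled with a dummy value of 1, while cache_position simply counts every slot. The helper name is hypothetical, and how Phi4MultimodalModel derives position_ids internally is not documented on this page.

```python
def positions_for_left_padding(attention_mask):
    """Hypothetical helper: return (position_ids, cache_position) for one sequence."""
    position_ids = []
    running = -1
    for m in attention_mask:
        running += m
        # Padded slots get a dummy position of 1; attention ignores them anyway.
        position_ids.append(running if m == 1 else 1)
    # cache_position ignores padding and indexes every slot in the sequence.
    cache_position = list(range(len(attention_mask)))
    return position_ids, cache_position


print(positions_for_left_padding([0, 0, 1, 1, 1]))  # ([1, 1, 0, 1, 2], [0, 1, 2, 3, 4])
```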

The Phi4MultimodalModel forward method overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Phi4MultimodalForCausalLM

class transformers.Phi4MultimodalForCausalLM

( config )

forward

(input_ids: LongTensor = None, attention_mask: typing.Optional[torch.Tensor] = None, position_ids: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[typing.List[torch.FloatTensor]] = None, inputs_embeds: typing.Optional[torch.FloatTensor] = None, image_pixel_values: typing.Optional[torch.FloatTensor] = None, image_sizes: typing.Optional[torch.LongTensor] = None, image_attention_mask=None, audio_input_features: typing.Optional[torch.FloatTensor] = None, audio_embed_sizes=None, audio_attention_mask=None, labels: typing.Optional[torch.LongTensor] = None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None, cache_position: typing.Optional[torch.LongTensor] = None, logits_to_keep: typing.Union[int, torch.Tensor] = 0, **kwargs) -> transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

Parameters

  • input_ids (torch.LongTensor of shape (batch_size, sequence_length)) — Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.

    Indices can be obtained using AutoTokenizer. See PreTrainedTokenizer.encode() and PreTrainedTokenizer.call() for details.

    What are input IDs?

  • attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices in input_ids. Mask values selected in [0, 1]:
    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

  • position_ids (torch.LongTensor of shape (batch_size, sequence_length), optional) — Indices of positions of each input sequence token in the position embeddings. Selected in the range [0, config.max_position_embeddings - 1].

    What are position IDs?

  • past_key_values (Cache, optional) — Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists of the past_key_values returned by the model at a previous stage of decoding, when use_cache=True or config.use_cache=True. See the KV cache guide.

    If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

  • inputs_embeds (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional) — Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert input_ids indices into associated vectors than the model’s internal embedding lookup matrix.
  • image_pixel_values (torch.FloatTensor, optional) — If the input contains images, these correspond to the pixel values after transformations (as returned by the processor).
  • image_sizes (torch.LongTensor, optional) — If the input contains images, these correspond to the size of each image.
  • image_attention_mask (torch.LongTensor, optional) — Attention mask for the images.
  • audio_input_features (torch.FloatTensor, optional) — If the input contains audio samples, these correspond to the values after transformation (as returned by the processor).
  • audio_embed_sizes (torch.Tensor, optional) — Size of the audio inputs.
  • audio_attention_mask (torch.Tensor, optional) — Attention mask for the audio inputs.
  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • cache_position (torch.LongTensor of shape (sequence_length), optional) — Indices depicting the position of the input sequence tokens in the sequence. Unlike position_ids, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length.
  • labels (torch.LongTensor of shape (batch_size, sequence_length), optional) — Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked), the loss is only computed for the tokens with labels in [0, ..., config.vocab_size].
  • logits_to_keep (int or torch.Tensor, optional) — If an int, compute logits for the last logits_to_keep tokens. If 0, calculate logits for all input_ids (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a torch.Tensor, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length).
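The labels and logits_to_keep arguments can be illustrated with plain lists. The hypothetical build_labels helper masks prompt tokens with -100 so that only answer tokens contribute to the loss, and keep_logits mimics the documented int and 1D-index forms of logits_to_keep on a list of per-position logits. Neither helper is part of the library, and the token ids are made up.

```python
IGNORE_INDEX = -100  # label value skipped by the loss, per the labels description


def build_labels(prompt_ids, answer_ids):
    """Hypothetical helper: concatenate prompt and answer, masking the prompt in the labels."""
    input_ids = prompt_ids + answer_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
    return input_ids, labels


def keep_logits(logits, logits_to_keep):
    """Mimic logits_to_keep: an int keeps the last N positions, a list keeps those indices."""
    if isinstance(logits_to_keep, int):
        # logits[-0:] is logits[0:], so 0 keeps every position.
        return logits[-logits_to_keep:]
    return [logits[i] for i in logits_to_keep]


input_ids, labels = build_labels([5, 6, 7], [8, 9])
print(labels)  # [-100, -100, -100, 8, 9]
print(keep_logits([[0.1], [0.2], [0.3]], 1))  # [[0.3]]
```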

Returns

transformers.modeling_outputs.CausalLMOutputWithPast or tuple(torch.FloatTensor)

A transformers.modeling_outputs.CausalLMOutputWithPast or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Phi4MultimodalConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss (for next-token prediction).

  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (tuple(tuple(torch.FloatTensor)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(torch.FloatTensor) of length config.num_hidden_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head).

    Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
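The past_key_values shapes make it possible to estimate the memory cost of a long context. The sketch assumes Hugging Face-style caching, where keys and values are typically stored with num_key_value_heads heads (before being repeated for GQA), 2 bytes per value (float16 or bfloat16), and batch size 1; it covers only the text decoder, not the vision or audio encoders.

```python
# Defaults from the Phi4MultimodalConfig signature.
num_hidden_layers = 32
num_key_value_heads = 8
head_dim = 3072 // 32   # hidden_size // num_attention_heads = 96
bytes_per_value = 2     # float16 / bfloat16 (assumption)


def kv_cache_bytes(seq_len: int, batch_size: int = 1) -> int:
    """Rough size of the decoder KV cache: one key and one value tensor per layer."""
    per_layer = 2 * batch_size * num_key_value_heads * seq_len * head_dim  # 2 = key + value
    return num_hidden_layers * per_layer * bytes_per_value


full_context = kv_cache_bytes(131072)  # max_position_embeddings
print(f"{full_context / 1024**3:.1f} GiB")  # 12.0 GiB
```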

The Phi4MultimodalForCausalLM forward method overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM
>>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA")
>>> tokenizer = AutoTokenizer.from_pretrained("TBA")
>>> prompt = "This is an example script ."
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'