# processing_moss_vl.py — MOSS-VL-Instruct-0408 (uploaded via huggingface_hub, revision b66ac48)
# coding=utf-8
# Copyright 2025 The FNLP Vision Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Moss-VL.
"""
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torchvision.transforms.v2 import functional as F
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, SizeDict
from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images
from transformers.utils import TensorType
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
VideosKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
# Module-level logger; transformers' logging wrapper honors the HF verbosity settings
# and provides `warning_once` used below.
logger = logging.get_logger(__name__)
class MossVLImageProcessorFast(Qwen2VLImageProcessorFast):
    """
    Custom image processor that overrides _preprocess to support multi_image_max_pixels.
    Inherits from Qwen2VLImageProcessorFast.

    ``multi_image_max_pixels`` is a *batch-level* total pixel budget shared by all images
    in a call, while ``size["longest_edge"]`` stays a per-image upper bound and
    ``size["shortest_edge"]`` a per-image lower bound.
    """

    # Multi-image batch total pixels limit (read from config).
    # None/0 falls back to the per-image longest_edge limit in _preprocess.
    multi_image_max_pixels = None

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ):
        """Override _preprocess to use custom smart_resize with batch-level max_pixels.

        multi_image_max_pixels is treated as a batch-level total budget, proportionally allocated
        to each image based on its original pixel count. min_pixels remains a per-image
        constraint. multi_image_max_pixels can be configured separately from longest_edge.

        Returns a BatchFeature with:
          - pixel_values: (total_patches, channel * temporal_patch_size * patch_size**2)
          - image_grid_thw: (num_images, 3) of [grid_t, grid_h, grid_w]
        """
        min_pixels = size["shortest_edge"]
        max_pixels = size["longest_edge"]  # Per-image upper limit
        # Use multi_image_max_pixels if configured, otherwise fall back to longest_edge
        multi_image_max_pixels = getattr(self, "multi_image_max_pixels", None) or max_pixels
        # Calculate total original pixels across all images in the batch.
        # This is used to proportionally allocate max_pixels to each image.
        total_original_pixels = sum(img.shape[-2] * img.shape[-1] for img in images)
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if do_resize:
                # Calculate proportional max_pixels for images with this shape.
                # Each image's max_pixels is allocated based on its proportion of total pixels.
                original_pixels = height * width
                if total_original_pixels > 0:
                    proportion = original_pixels / total_original_pixels
                    proportional_max_pixels = int(multi_image_max_pixels * proportion)
                else:
                    proportional_max_pixels = multi_image_max_pixels
                # Clamp the proportional budget into [min_pixels, max_pixels]:
                # min_pixels: per-image lower limit (shortest_edge)
                # max_pixels: per-image upper limit (longest_edge)
                proportional_max_pixels = max(proportional_max_pixels, min_pixels)
                proportional_max_pixels = min(proportional_max_pixels, max_pixels)
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=min_pixels,
                    max_pixels=proportional_max_pixels,
                )
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    interpolation=interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
        # Warn if multi-image batch exceeds multi_image_max_pixels due to the min_pixels constraint
        # (the per-image floor can force the total over budget).
        if len(images) > 1:
            total_resized_pixels = sum(img.shape[-2] * img.shape[-1] for img in resized_images)
            if total_resized_pixels > multi_image_max_pixels:
                logger.warning_once(
                    f"Multi-image batch total pixels ({total_resized_pixels}) exceeds multi_image_max_pixels ({multi_image_max_pixels}). "
                    f"This may happen when image_count * min_pixels > multi_image_max_pixels."
                )
        # Group images by size for further processing.
        # Needed in case do_resize is False, or resize returns images with different sizes.
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            # Fused rescale and normalize
            patches = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            if patches.ndim == 4:
                # add a temporal dimension if we have images (not video frames)
                patches = patches.unsqueeze(1)
            if patches.shape[1] % temporal_patch_size != 0:
                # Pad the temporal axis by repeating the last frame so it divides evenly
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)
            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            # Reorder dimensions to group grid and patch information for subsequent flattening.
            # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )
            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)
        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
        )
def _to_numpy(x):
"""
Convert various tensor types to numpy array.
Supports torch.Tensor, tf.Tensor, jax.Array, np.ndarray, lists, and primitives.
Args:
x: Input value that can be a tensor from various frameworks or a Python primitive
Returns:
np.ndarray: NumPy array representation of the input
"""
# Already numpy
if isinstance(x, np.ndarray):
return x
# Torch tensor or TensorFlow tensor (both have .numpy() method)
if hasattr(x, 'numpy'):
# For torch tensors on CUDA, need to move to CPU first
if hasattr(x, 'cpu'):
return x.cpu().numpy()
# For TensorFlow or already on CPU
return x.numpy()
# JAX arrays and other array-like objects that support __array__ protocol
if hasattr(x, '__array__'):
return np.asarray(x)
# Python primitives (list, tuple, int, float)
return np.array(x)
class MossVLImagesKwargs(ImagesKwargs, total=False):
    """Optional image-processing overrides accepted by `MossVLProcessor.__call__`.

    `total=False` marks every key as optional (consistent with the videos kwargs
    class below); totality only affects static type checking, not runtime merging.
    """

    min_pixels: Optional[int]
    max_pixels: Optional[int]
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]
class MossVLVideosKwargs(VideosKwargs, total=False):
    """Optional video-processing overrides accepted by `MossVLProcessor.__call__`."""

    # Sampling rate used when extracting frames from a video file.
    video_fps: Optional[Union[int, float]]
    # Lower/upper bounds on the number of frames to sample.
    min_frames: Optional[int]
    max_frames: Optional[int]
    # Parallelism for frame extraction — presumably thread count; confirm against the video processor.
    num_extract_threads: Optional[int]
class MossVLProcessorKwargs(ProcessingKwargs, total=False):
    """Typed kwargs for `MossVLProcessor.__call__`, with per-modality defaults."""

    images_kwargs: MossVLImagesKwargs
    videos_kwargs: MossVLVideosKwargs
    _defaults = {
        "text_kwargs": {
            # No padding by default; callers enable padding for batched inference.
            "padding": False,
            "return_token_type_ids": False,
            "return_mm_token_type_ids": False,
        },
        # Video metadata (fps, frame indices, duration) is required to build the
        # timestamped video prompts in MossVLProcessor.__call__.
        "videos_kwargs": {"return_metadata": True},
    }
class MossVLProcessor(ProcessorMixin):
    r"""
    Constructs a Moss-VL processor which wraps a Qwen2VL image processor, Moss-VL video processor and a Qwen2 tokenizer
    into a single processor.
    [`MossVLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`], [`MossVLVideoProcessor`] and [`Qwen2TokenizerFast`].
    See the [`~MossVLProcessor.__call__`] and [`~MossVLProcessor.decode`] for more information.

    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        video_processor ([`MossVLVideoProcessor`], *optional*):
            The video processor is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    # Sub-processors managed by ProcessorMixin; loaded via the Auto* classes below.
    attributes = ["image_processor", "tokenizer", "video_processor"]
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
    self,
    image_processor=None,
    tokenizer=None,
    video_processor=None,
    chat_template=None,
    **kwargs
):
    """Wire up the sub-processors and resolve all special-token strings and ids."""
    super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
    # Special-token strings: honor tokenizer-provided attributes when present,
    # otherwise fall back to the Qwen2-VL defaults.
    self.image_token = getattr(tokenizer, "image_token", "<|image_pad|>")
    self.video_token = getattr(tokenizer, "video_token", "<|video_pad|>")
    self.vision_start_token = getattr(tokenizer, "vision_start_token", "<|vision_start|>")
    self.vision_end_token = getattr(tokenizer, "vision_end_token", "<|vision_end|>")
    # Token ids: prefer a truthy id exposed by the tokenizer, else resolve from the string.
    self.image_token_id = (
        getattr(tokenizer, "image_token_id", None)
        or tokenizer.convert_tokens_to_ids(self.image_token)
    )
    self.video_token_id = (
        getattr(tokenizer, "video_token_id", None)
        or tokenizer.convert_tokens_to_ids(self.video_token)
    )
    # Placeholders used in input text
    self.image_placeholder = "<|image|>"
    self.video_placeholder = "<|video|>"
    self.time_start_token = "<|time_start|>"
    self.time_end_token = "<|time_end|>"
    # EOS token for labels generation (assistant's response should end with this)
    self.im_end_token = "<|im_end|>"
    self.im_end_token_id = tokenizer.convert_tokens_to_ids(self.im_end_token)
    # Vision-related token ids (all should be masked in labels)
    self.vision_start_token_id = tokenizer.convert_tokens_to_ids(self.vision_start_token)
    self.vision_end_token_id = tokenizer.convert_tokens_to_ids(self.vision_end_token)
    # Token ids that should always be masked in labels (e.g. <|image_pad|>)
    self.mask_token_ids = {self.image_token_id}
def __call__(
    self,
    text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
    images: ImageInput = None,
    videos: Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]] = None,
    labels_spans: Optional[Union[List[tuple], List[List[tuple]]]] = None,
    ignore_index: int = -100,
    **kwargs: Unpack[MossVLProcessorKwargs],
) -> BatchFeature:
    """
    Main method to prepare for the model one or several sequences(s) and image(s)/video(s).

    Args:
        text (`str`, `list[str]`, `list[list[str]]`):
            The sequence or batch of sequences to be encoded. Media positions are marked with
            the `<|image|>` / `<|video|>` placeholders.
        images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
            The image or batch of images to be prepared.
        videos (`str`, `Dict`, `list[str]`, `list[Dict]`):
            The video or batch of videos to be prepared. Each video can be:
            - A string path to a video file
            - A dict with keys:
                - "video_path": str, path to the video file
                - "segments": list of segments, where each segment is:
                    - [start, end]: a time segment (left-closed, right-open interval in seconds)
                    - [time]: a single frame at the specified time (in seconds)
            The number of segments should match the number of video placeholders in the text.
        labels_spans (`list[list[int]]`, `list[list[list[int]]]`, *optional*):
            Character-level spans indicating assistant regions in original text.
            Each span is a [start, end] list with inclusive start and exclusive end.
            Example: [[10, 50], [100, 150]] means characters [10:50) and [100:150) are assistant.
            Note: Use list (not tuple) for spans as they will be modified in place during processing.
            When provided, the processor will generate `labels` in the output, where:
            - Non-assistant tokens have value `ignore_index` (-100 by default)
            - Image tokens always have value `ignore_index` even in assistant part
            - Other assistant tokens have their token id as label
        ignore_index (`int`, *optional*, defaults to -100):
            Value for masked positions in labels.
        return_tensors (`str` or [`~utils.TensorType`], *optional*):
            If set, will return tensors of a particular framework. Acceptable values are:
            - `'tf'`: Return TensorFlow `tf.constant` objects.
            - `'pt'`: Return PyTorch `torch.Tensor` objects.
            - `'np'`: Return NumPy `np.ndarray` objects.
            - `'jax'`: Return JAX `jnp.ndarray` objects.

    Returns:
        [`BatchFeature`]: A [`BatchFeature`] with the following fields:
        - **input_ids** -- List of token ids to be fed to a model.
        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
        - **pixel_values** -- Pixel values to be fed to a model (concatenation of images and videos).
        - **grid_thw** -- List of grid sizes (t, h, w) for each media item.
        - **media_nums_per_sample** -- List of number of media items per sample.
        - **labels** -- (Optional) Labels for training, only present when `labels_spans` is provided.
    """
    # Merge caller kwargs with the class defaults and tokenizer init kwargs
    output_kwargs = self._merge_kwargs(
        MossVLProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )
    # Step 1: Process images if provided
    if images is not None:
        images_kwargs = output_kwargs["images_kwargs"].copy()
        # Keep raw outputs here; tensor conversion happens once at the end via BatchFeature
        images_kwargs["return_tensors"] = None
        image_inputs = self.image_processor(images=images, **images_kwargs)
        image_grid_thw = image_inputs["image_grid_thw"]
    else:
        image_inputs = {}
        image_grid_thw = None
    # Step 2: Process videos if provided
    if videos is not None:
        videos_kwargs = output_kwargs["videos_kwargs"].copy()
        videos_kwargs["return_tensors"] = None
        videos_inputs = self.video_processor(videos=videos, **videos_kwargs)
        video_grid_thw = videos_inputs["video_grid_thw"]
        # If user has not requested video metadata, pop it (still needed locally for timestamps)
        if "return_metadata" not in kwargs:
            video_metadata = videos_inputs.pop("video_metadata")
        else:
            video_metadata = videos_inputs["video_metadata"]
    else:
        videos_inputs = {}
        video_grid_thw = None
        video_metadata = None
    # Step 3: Process text with placeholder replacement
    if text is None or (isinstance(text, str) and len(text.strip()) == 0):
        raise ValueError("Text input is required for MossVL processor and cannot be empty.")
    if not isinstance(text, list):
        text = [text]
    text = text.copy()  # Copy to avoid in-place modifications of the caller's list
    # Prepare labels_spans if provided.
    # labels_spans format: List[List[List[int]]] - batch of samples, each sample has multiple spans.
    # Each span is [start, end] (list, not tuple) so it can be modified in place.
    should_create_labels = labels_spans is not None
    if should_create_labels:
        # Ensure batch format: convert single sample spans to batch format
        # Single sample: [[start, end], [start, end], ...]
        # Batch: [[[start, end], ...], [[start, end], ...], ...]
        if labels_spans and isinstance(labels_spans[0], list) and len(labels_spans[0]) == 2 and isinstance(labels_spans[0][0], int):
            labels_spans = [labels_spans]
    # Step 3.0-pre: Check if we need to reorder (when both images and videos exist).
    # If only one media type exists, we can skip the expensive split+reorder+concat.
    has_images = images is not None and "pixel_values" in image_inputs
    has_videos = videos is not None and "pixel_values_videos" in videos_inputs
    needs_reorder = has_images and has_videos
    image_pixel_values_list = []
    video_pixel_values_list = []
    # Step 3.0: Record the order of media in original text (before replacement).
    # This will be used later to correctly order pixel_values and grid_thw.
    media_order_per_sample = []
    for i in range(len(text)):
        media_order = []
        temp_text = text[i]
        pos = 0
        # Scan left-to-right for the earliest placeholder of either type
        while pos < len(temp_text):
            img_pos = temp_text.find(self.image_placeholder, pos)
            vid_pos = temp_text.find(self.video_placeholder, pos)
            if img_pos == -1 and vid_pos == -1:
                break
            if img_pos != -1 and (vid_pos == -1 or img_pos < vid_pos):
                media_order.append(("image", img_pos))
                pos = img_pos + len(self.image_placeholder)
            elif vid_pos != -1:
                media_order.append(("video", vid_pos))
                pos = vid_pos + len(self.video_placeholder)
        media_order_per_sample.append(media_order)
    # Step 3.0.1: Check if any sample has no media (empty samples need a blank image).
    # If there are empty samples, we need to enter the slow path to handle them properly.
    has_empty_samples = any(len(order) == 0 for order in media_order_per_sample)
    if has_empty_samples:
        needs_reorder = True
    # Split flat pixel values into per-media chunks for reordering if needed
    if needs_reorder:
        if has_images:
            flat_pixel_values = image_inputs["pixel_values"]
            flat_grid_thw = image_inputs["image_grid_thw"]
            # grid_thw is (t, h, w), num_patches = t * h * w
            patch_counts = [int(np.prod(_to_numpy(grid))) for grid in flat_grid_thw]
            if len(patch_counts) == 1:
                # Single image case: no need to split
                image_pixel_values_list = [flat_pixel_values]
            elif len(patch_counts) > 1:
                # Multiple images: split by cumulative patch counts
                split_indices = np.cumsum(patch_counts)[:-1]
                image_pixel_values_list = np.split(flat_pixel_values, split_indices)
        if has_videos:
            flat_video_values = videos_inputs["pixel_values_videos"]
            flat_video_grid = videos_inputs["video_grid_thw"]
            video_patch_counts = [int(np.prod(_to_numpy(grid))) for grid in flat_video_grid]
            if len(video_patch_counts) == 1:
                # Single video case: no need to split
                video_pixel_values_list = [flat_video_values]
            elif len(video_patch_counts) > 1:
                # Multiple videos: split by cumulative patch counts
                split_indices = np.cumsum(video_patch_counts)[:-1]
                video_pixel_values_list = np.split(flat_video_values, split_indices)
    # Step 3.1: Replace placeholders (simple replacement, no expansion yet).
    # In MossVL, one image placeholder = one image token.
    # One video placeholder = one video token (will be expanded later).
    for i in range(len(text)):
        if should_create_labels:
            # Replace and update spans for image placeholders
            text[i], labels_spans[i] = self._replace_and_update_spans(
                text[i], self.image_placeholder, self.image_token, labels_spans[i]
            )
            # Replace and update spans for video placeholders
            text[i], labels_spans[i] = self._replace_and_update_spans(
                text[i], self.video_placeholder, self.video_token, labels_spans[i]
            )
        else:
            text[i] = text[i].replace(self.image_placeholder, self.image_token)
            text[i] = text[i].replace(self.video_placeholder, self.video_token)
    # Step 3.2: Validate token counts against the provided media
    n_images_in_text = [t.count(self.image_token) for t in text]
    n_videos_in_text = [t.count(self.video_token) for t in text]
    # Count placeholders in text
    total_images_in_text = sum(n_images_in_text)
    total_videos_in_text = sum(n_videos_in_text)
    # Count actual images and videos provided
    total_images_provided = len(image_grid_thw) if image_grid_thw is not None else 0
    total_videos_provided = len(video_grid_thw) if video_grid_thw is not None else 0
    # Validate image counts
    if total_images_in_text != total_images_provided:
        raise ValueError(
            "Number of image tokens does not match number of images provided. "
            f"Found {total_images_in_text} image tokens in text and {total_images_provided} images."
        )
    # Validate video counts
    if total_videos_in_text != total_videos_provided:
        raise ValueError(
            "Number of video tokens does not match number of videos provided. "
            f"Found {total_videos_in_text} video tokens in text and {total_videos_provided} videos."
        )
    # Step 3.3: Expand video tokens with timestamps.
    # Now expand each video token to multiple tokens (one per temporal grid step) with timestamps.
    if video_grid_thw is not None:
        index = 0
        for i in range(len(text)):
            while self.video_token in text[i]:
                metadata = video_metadata[index]
                if metadata.fps is None:
                    logger.warning_once(
                        "MossVL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
                        "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
                        "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
                    )
                metadata.fps = 24 if metadata.fps is None else metadata.fps
                # Calculate timestamps.
                # Use actual_timestamps if available (for segments), otherwise use frames_indices.
                actual_timestamps = getattr(metadata, 'actual_timestamps', None)
                curr_timestamp = self._calculate_timestamps(
                    metadata.frames_indices,
                    metadata.total_num_frames,
                    metadata.fps,
                    metadata.duration,
                    self.video_processor.temporal_patch_size,
                    actual_timestamps=actual_timestamps,
                )
                # Build video placeholder: one image token per frame with timestamp.
                # video_grid_thw[index][0] is the temporal dimension (number of frames after merging).
                video_tokens = []
                for frame_idx in range(video_grid_thw[index][0]):
                    curr_time = curr_timestamp[frame_idx]
                    # Format: <|time_start|>X.X seconds<|time_end|><|image_pad|>
                    video_tokens.append(
                        f"{self.time_start_token}{curr_time:.1f} seconds{self.time_end_token}{self.image_token}"
                    )
                # Wrap the entire video sequence with vision_start and vision_end tokens
                video_placeholder = f"{self.vision_start_token}{''.join(video_tokens)}{self.vision_end_token}"
                # Replace the video token with expanded sequence and update spans if needed
                if should_create_labels:
                    text[i], labels_spans[i] = self._replace_and_update_spans(
                        text[i], self.video_token, video_placeholder, labels_spans[i], replace_count=1
                    )
                else:
                    text[i] = text[i].replace(self.video_token, video_placeholder, 1)
                index += 1
    # Step 4: Tokenize text
    return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
    return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
    # Request offset_mapping if we need to create labels
    if should_create_labels:
        output_kwargs["text_kwargs"]["return_offsets_mapping"] = True
    text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
    # ignore check_special_mm_tokens nums in test and input ids.
    # self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
    # Create labels if labels_spans was provided
    if should_create_labels:
        offset_mapping = text_inputs.pop("offset_mapping")
        labels = self._create_labels_from_spans(
            text_inputs["input_ids"],
            offset_mapping,
            labels_spans,
            ignore_index
        )
    if return_mm_token_type_ids:
        array_ids = np.array(text_inputs["input_ids"])
        mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
        mm_token_type_ids[array_ids == self.image_token_id] = 1
        text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
    # Step 5: Concatenate pixel_values and grid_thw in sequence order
    # Prepare output
    output_data = {**text_inputs}
    if not needs_reorder:
        # Fast path: only one media type, no reordering needed
        final_pixel_values = []
        final_grid_thw = []
        if has_images:
            final_pixel_values.append(image_inputs["pixel_values"])
            final_grid_thw.extend(image_grid_thw)
        if has_videos:
            final_pixel_values.append(videos_inputs["pixel_values_videos"])
            final_grid_thw.extend(video_grid_thw)
        if final_pixel_values:
            output_data["pixel_values"] = np.concatenate(final_pixel_values, axis=0) if len(final_pixel_values) > 1 else final_pixel_values[0]
        if final_grid_thw:
            output_data["grid_thw"] = np.stack(final_grid_thw, axis=0)
        # Calculate media_nums_per_sample (samples with no media count as 1 — a blank slot)
        media_nums_per_sample = []
        for batch_idx in range(len(text)):
            media_order = media_order_per_sample[batch_idx]
            media_nums_per_sample.append(len(media_order) if len(media_order) > 0 else 1)
        # Don't add media_nums_per_sample to output_data yet.
        # Will add it after BatchFeature to keep it as a plain list.
    else:
        # Slow path: both images and videos exist (or empty samples), need reordering
        final_pixel_values = []
        final_grid_thw = []
        media_nums_per_sample = []
        # Global indices to track position in flattened image/video arrays
        global_image_idx = 0
        global_video_idx = 0
        for batch_idx in range(len(text)):
            # Use the recorded media order from Step 3.0
            media_order = media_order_per_sample[batch_idx]
            if len(media_order) == 0:
                # If no media provided for this sample, add a blank image
                media_nums_per_sample.append(1)
                min_pixels = 128 * 128
                patch_size = getattr(self.image_processor, "patch_size", None) or 16
                temporal_patch_size = getattr(self.image_processor, "temporal_patch_size", None) or 1
                merge_size = getattr(self.image_processor, "merge_size", None) or 2
                factor = patch_size * merge_size
                # Smallest square side (multiple of factor) covering min_pixels
                side = int(np.ceil(np.sqrt(min_pixels) / factor) * factor)
                grid_h = side // patch_size
                grid_w = side // patch_size
                grid_t = 1
                # Channel = 3 (RGB)
                channel = 3
                dim = channel * temporal_patch_size * patch_size * patch_size
                num_patches = grid_t * grid_h * grid_w
                blank_pixel_values = np.zeros((num_patches, dim), dtype=np.float32)
                blank_grid_thw = np.array([grid_t, grid_h, grid_w], dtype=np.int64)
                final_pixel_values.append(blank_pixel_values)
                final_grid_thw.append(blank_grid_thw)
            else:
                media_nums_per_sample.append(len(media_order))
                # Collect media data according to the recorded order
                for media_type, _ in media_order:
                    if media_type == "image" and image_grid_thw is not None:
                        # Get image data
                        if image_pixel_values_list:
                            final_pixel_values.append(image_pixel_values_list[global_image_idx])
                            final_grid_thw.append(image_grid_thw[global_image_idx])
                        global_image_idx += 1
                    elif media_type == "video" and video_grid_thw is not None:
                        # Get video data
                        if video_pixel_values_list:
                            final_pixel_values.append(video_pixel_values_list[global_video_idx])
                            final_grid_thw.append(video_grid_thw[global_video_idx])
                        global_video_idx += 1
        # Concatenate/stack to unified format
        if final_pixel_values:
            output_data["pixel_values"] = np.concatenate(final_pixel_values, axis=0)
        if final_grid_thw:
            output_data["grid_thw"] = np.stack(final_grid_thw, axis=0)
        # Don't add media_nums_per_sample to output_data yet.
        # Will add it after BatchFeature to keep it as a plain list.
    # Create cross_attention_mask using media_nums_per_sample
    if "input_ids" in output_data and "grid_thw" in output_data and media_nums_per_sample:
        cross_attention_mask = self._create_cross_attention_mask(
            output_data["input_ids"],
            output_data["grid_thw"],
            media_nums_per_sample,
            output_data.get("attention_mask", None)
        )
        output_data["cross_attention_mask"] = cross_attention_mask
    # Add labels to output if created
    if should_create_labels:
        output_data["labels"] = labels
    # BatchFeature will handle conversion to pt/tf/jax/np based on tensor_type
    batch_feature = BatchFeature(data=output_data, tensor_type=return_tensors)
    # Add media_nums_per_sample after BatchFeature to keep it as list (not tensor)
    if media_nums_per_sample:
        batch_feature["media_nums_per_sample"] = media_nums_per_sample
    return batch_feature
def _create_cross_attention_mask(self, input_ids, grid_thw, media_nums_per_sample, attention_mask=None):
    """
    Create cross_attention_mask of shape (batch_size, 1, text_len, num_images).
    Video frames are treated as individual images.
    Mask values: True for masked, False for visible.
    Causal masking: text can see images that appear at or before the text position.

    Args:
        input_ids: List of token id lists (one per sample)
        grid_thw: Grid sizes for each media item, indexable as (N, 3) with t first
        media_nums_per_sample: Number of media items per sample
        attention_mask: Optional attention mask to filter out padding positions

    Returns:
        A bool tensor of shape (batch_size, 1, max_text_len, max_num_frames),
        or None when no sample has any frames.
    """
    batch_size = len(input_ids)
    max_text_len = max(len(ids) for ids in input_ids)
    # Calculate total frames per sample to find max_num_frames
    total_frames_per_sample = []
    media_idx = 0
    for b in range(batch_size):
        num_media = media_nums_per_sample[b]
        if num_media == 0:
            total_frames_per_sample.append(0)
            continue
        sample_frames = 0
        for _ in range(num_media):
            # grid_thw is (N, 3) where first dim is t (num_frames)
            t = grid_thw[media_idx][0]
            sample_frames += t
            media_idx += 1
        total_frames_per_sample.append(sample_frames)
    max_num_frames = max(total_frames_per_sample) if total_frames_per_sample else 0
    if max_num_frames == 0:
        return None
    # Vectorized implementation for speed
    # 1. Pad input_ids to create a tensor.
    #    We use -1 as pad value since real token ids are non-negative.
    input_ids_tensor = torch.full((batch_size, max_text_len), -1, dtype=torch.long)
    for b, ids in enumerate(input_ids):
        l = len(ids)
        input_ids_tensor[b, :l] = torch.tensor(ids, dtype=torch.long)
    # 2. Identify image tokens
    is_image_token = (input_ids_tensor == self.image_token_id)
    # 3. Compute cumulative image tokens (how many image tokens appeared up to position t)
    #    shape: (batch_size, text_len)
    cum_image_tokens = is_image_token.cumsum(dim=1)
    # 4. Create frame indices
    #    shape: (1, 1, max_num_frames)
    frame_indices = torch.arange(max_num_frames).reshape(1, 1, -1)
    # 5. Determine visibility based on causal relationship.
    #    Text at `t` sees frame `i` if `cum_image_tokens[t] > i`,
    #    because if frame `i` is the (i+1)-th image token, it becomes visible when the count reaches i+1.
    #    shape: (batch_size, text_len, max_num_frames)
    visible_mask = cum_image_tokens.unsqueeze(-1) > frame_indices
    # 6. Apply attention_mask if provided (padding positions see nothing)
    if attention_mask is not None:
        # Convert to tensor if needed
        if isinstance(attention_mask, torch.Tensor):
            attn_mask_tensor = attention_mask
        else:
            # List of lists
            attn_mask_tensor = torch.zeros((batch_size, max_text_len), dtype=torch.long)
            for b, mask_row in enumerate(attention_mask):
                l = len(mask_row)
                attn_mask_tensor[b, :l] = torch.tensor(mask_row, dtype=torch.long)
        # shape: (batch_size, text_len, 1)
        valid_text = (attn_mask_tensor.unsqueeze(-1) == 1)
        visible_mask = visible_mask & valid_text
    # 7. Mask out frames that don't exist for a sample
    #    shape: (batch_size, 1, 1)
    total_frames_tensor = torch.tensor(total_frames_per_sample).reshape(batch_size, 1, 1)
    # shape: (batch_size, 1, max_num_frames)
    valid_frames = frame_indices < total_frames_tensor
    visible_mask = visible_mask & valid_frames
    # 8. Create final mask (True for masked, False for visible)
    mask = ~visible_mask
    # 9. Add channel dimension: (batch_size, 1, text_len, max_num_frames)
    mask = mask.unsqueeze(1)
    return mask
def _replace_and_update_spans(
    self,
    text: str,
    old_str: str,
    new_str: str,
    spans: List[List[int]],
    replace_count: int = -1
) -> tuple:
    """
    Replace occurrences of old_str with new_str, shifting character spans to match.

    Spans are [start, end] lists and are mutated in place: a span that begins after a
    replacement point shifts entirely by the length difference; a span that straddles
    the replacement point only has its end shifted.

    Args:
        text: The text to perform replacement on
        old_str: String to be replaced
        new_str: String to replace with
        spans: List of [start, end] spans to update (modified in place)
        replace_count: Maximum number of replacements (-1 for all)

    Returns:
        Tuple of (new_text, updated_spans)
    """
    shift = len(new_str) - len(old_str)
    updated = text
    done = 0
    cursor = 0
    # Walk the string left-to-right, applying at most replace_count substitutions
    while replace_count == -1 or done < replace_count:
        hit = updated.find(old_str, cursor)
        if hit == -1:
            break
        for span in spans:
            if span[0] > hit:
                # Span lies entirely after the replacement point: shift both ends
                span[0] += shift
                span[1] += shift
            elif span[1] > hit:
                # Span straddles the replacement point: only the end moves
                span[1] += shift
        # Splice in the replacement and resume scanning after it
        updated = updated[:hit] + new_str + updated[hit + len(old_str):]
        cursor = hit + len(new_str)
        done += 1
    return updated, spans
def _create_labels_from_spans(
self,
input_ids: List[List[int]],
offset_mapping: List[List[tuple]],
labels_spans: List[List[List[int]]],
ignore_index: int = -100,
mask_token_ids: Optional[set] = None
) -> List[List[int]]:
"""
Create labels from spans and offset_mapping.
Args:
input_ids: Tokenized input ids
offset_mapping: Character offsets for each token from tokenizer (special tokens included)
labels_spans: Updated spans indicating assistant regions (after text transformations)
ignore_index: Value for masked positions
mask_token_ids: Set of token ids that should always be masked (set to ignore_index)
in labels, regardless of whether they fall inside a span.
Defaults to self.mask_token_ids if not provided.
Returns:
labels: List of label ids, same shape as input_ids
Note:
- Tokenizer's offset_mapping already includes correct offsets for special tokens in text
- Only need to mask tokens inside <|vision_start|>...<|vision_end|>
- Tokens whose id is in mask_token_ids are always masked
- All other tokens in spans (including special tokens like <|im_end|>) get labels
"""
if mask_token_ids is None:
mask_token_ids = self.mask_token_ids
batch_labels = []
for batch_idx in range(len(input_ids)):
ids = input_ids[batch_idx]
offsets = offset_mapping[batch_idx]
spans = labels_spans[batch_idx]
labels = [ignore_index] * len(ids)
# Process each span: find token range and set labels
for span_start, span_end in spans:
in_vision = False
# Find tokens that overlap with this span
for token_idx, (token_id, (char_start, char_end)) in enumerate(zip(ids, offsets)):
# Skip tokens completely before this span
if char_end <= span_start:
continue
# Stop when tokens are completely after this span
if char_start >= span_end:
break
# Token overlaps with span, process it
# Track vision region: <|vision_start|> ... <|vision_end|>
if token_id == self.vision_start_token_id:
in_vision = True
continue
if token_id == self.vision_end_token_id:
in_vision = False
continue
# Skip tokens inside vision region
if in_vision:
continue
# Always mask special tokens that should never have labels
if token_id in mask_token_ids:
continue
# Set label for this token
labels[token_idx] = token_id
batch_labels.append(labels)
return batch_labels
def _calculate_timestamps(
self,
frames_indices: Optional[Union[List[int], np.ndarray]],
total_num_frames: int,
video_fps: float,
duration: float,
merge_size: int = 1,
actual_timestamps: Optional[List[float]] = None
):
"""
Calculate timestamps for video frames.
Args:
frames_indices: Actual frame indices extracted (if available)
total_num_frames: Total number of sampled frames
video_fps: Video frames per second
duration: Video duration in seconds
merge_size: Temporal merge size
actual_timestamps: Pre-calculated actual timestamps (for segments)
Returns:
List of timestamps (one per merged temporal patch)
"""
# If actual timestamps are provided (from segment), use them directly
if actual_timestamps is not None:
timestamps = list(actual_timestamps)
# Pad timestamps to be multiple of merge_size
if len(timestamps) % merge_size != 0:
timestamps.extend([timestamps[-1]] * (merge_size - len(timestamps) % merge_size))
# Frames are merged by merge_size, so we average the timestamps within each temporal patch
timestamps = [
(timestamps[i] + timestamps[i + merge_size - 1]) / 2
for i in range(0, len(timestamps), merge_size)
]
return timestamps
# Use frames_indices if available, otherwise generate uniformly sampled indices
if frames_indices is not None:
if isinstance(frames_indices, np.ndarray):
indices = frames_indices.tolist()
else:
indices = list(frames_indices)
else:
# Generate uniformly sampled frame indices
if total_num_frames <= 1:
indices = [0]
else:
# Uniformly sample frames across the video duration
indices = np.linspace(0, duration * video_fps - 1, total_num_frames).astype(np.int32).tolist()
# Pad indices to be multiple of merge_size
if len(indices) % merge_size != 0:
indices.extend([indices[-1]] * (merge_size - len(indices) % merge_size))
# Convert frame indices to timestamps
timestamps = [idx / video_fps for idx in indices]
# Frames are merged by merge_size, so we average the timestamps within each temporal patch
timestamps = [
(timestamps[i] + timestamps[i + merge_size - 1]) / 2
for i in range(0, len(timestamps), merge_size)
]
return timestamps
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to the tokenizer's batch_decode.
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to the tokenizer's decode.
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor
of shape `(batch_size, sequence_length)` or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.
Returns:
`list[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
# Public API of this module for `from ... import *` and documentation tools.
__all__ = ["MossVLProcessor", "MossVLImageProcessorFast"]