# pylint: disable=no-member # avoid weird pylint warnings from SentencePieceProcessor
"""Text and Image processor for CASA models using Qwen2.5_VL image encoder"""
from math import ceil
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload
from typing import cast as type_cast
import torch
import torchvision.transforms.v2 as T
from einops import rearrange
from PIL import Image
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import to_tensor as pil_to_tensor
from torchvision.transforms.v2 import functional as F
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
if TYPE_CHECKING:
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
ImageMessage = TypedDict(
"ImageMessage",
{
"type": Literal["image"],
"image": str | Image.Image | None,
},
)
TextMessage = TypedDict(
"TextMessage",
{
"type": Literal["text"],
"text": str,
},
)
MessageContent = list[ImageMessage | TextMessage]
Message = TypedDict(
"Message",
{
"role": Literal["system", "user", "assistant"],
"content": MessageContent,
},
)
ProcessorInput = list[list[Message]] | list[Message]
__INTERP_NAME_TO_MODE__ = {
"nearest": InterpolationMode.NEAREST,
"bilinear": InterpolationMode.BILINEAR,
"bicubic": InterpolationMode.BICUBIC,
"lanczos": InterpolationMode.LANCZOS,
}
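# Integer codes match PIL's Image.Resampling values
# (NEAREST=0, LANCZOS=1, BILINEAR=2, BICUBIC=3, BOX=4, HAMMING=5)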
__INTERP_INT_TO_MODE__ = {
0: InterpolationMode.NEAREST,
2: InterpolationMode.BILINEAR,
3: InterpolationMode.BICUBIC,
4: InterpolationMode.BOX,
5: InterpolationMode.HAMMING,
1: InterpolationMode.LANCZOS,
}
@overload
def universal_resize(
img: Image.Image,
size: tuple[int, int],
interpolation: str | InterpolationMode | int = "bilinear",
antialias: bool = True,
) -> Image.Image: ...
@overload
def universal_resize(
img: torch.Tensor,
size: tuple[int, int],
interpolation: str | InterpolationMode | int = "bilinear",
antialias: bool = True,
) -> torch.Tensor: ...
def universal_resize(
img: Image.Image | torch.Tensor,
size: tuple[int, int],
interpolation: str | InterpolationMode | int = "bilinear",
antialias: bool = True,
) -> Image.Image | torch.Tensor:
"""Resize that works for PIL.Image, CHW tensor, or BCHW tensor"""
if isinstance(interpolation, str):
interpolation = __INTERP_NAME_TO_MODE__[interpolation]
elif isinstance(interpolation, int):
interpolation = __INTERP_INT_TO_MODE__[interpolation]
return F.resize(
img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias
)
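# Illustrative calls (arbitrary example sizes), covering both the PIL and the tensor path:
#   universal_resize(Image.new("RGB", (640, 480)), (224, 224), "bicubic")
#   universal_resize(torch.rand(3, 480, 640), (224, 224), interpolation=2)  # 2 == bilinear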
@overload
def convert_to_rgb(img: Image.Image) -> Image.Image: ...
@overload
def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
"""Convert any image to RGB in a way that does not throw PIL warning"""
if isinstance(img, torch.Tensor):
return img
if img.mode == "RGB": # no changes
return img
if img.mode == "P": # palette images need to be converted to RGBA first
return img.convert("RGBA").convert("RGB")
return img.convert("RGB")
class QwenImageProcessor(BaseImageProcessor):
"""Resizing for the Qwen2.5VL encoder. Note that the normalization is
handled in the image_encoder in the model forward"""
def __init__(
self,
img_size: int = 448,
interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
max_ratio: int = 10,
round_to_patch_size: int = 56,
use_fast: bool = True,
**kwargs: Any,
) -> None:
# this will also be used in V2llms to determine whether to remove
# the temporal conv
self._num_target_channels = 588
self._merge_size = 2
self._patch_size = 14
super().__init__(
use_fast=use_fast,
do_normalize=False,
**kwargs,
)
self.img_size = img_size
self.interpolation = interpolation
self.max_ratio = max_ratio
self.round_to_patch_size = round_to_patch_size
def resize_transform(
self, img: Image.Image | torch.Tensor, img_size: int | None = None
) -> Image.Image | torch.Tensor:
if img_size is None:
img_size = self.img_size
max_area = img_size**2
if isinstance(img, Image.Image):
img = convert_to_rgb(img)
w_og, h_og = img.size
else:
h_og, w_og = img.shape[-2:]
w, h = w_og, h_og
        # Qwen caps the aspect ratio between the longer and shorter side (max_ratio,
        # 10 by default): bump the shorter side up so the cap is respected
        if self.max_ratio > 0:
            w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)
# resize to max area
current_area = w * h
if current_area > max_area:
scale = (max_area / current_area) ** 0.5
w, h = int(w * scale), int(h * scale)
        # round both sides up to a multiple of round_to_patch_size (56 by default);
        # a worked example follows this method
        if self.round_to_patch_size > 0:
            w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
            h = ceil(h / self.round_to_patch_size) * self.round_to_patch_size
# resize
if w != w_og or h != h_og:
img = universal_resize(img, (h, w), self.interpolation)
if isinstance(img, torch.Tensor):
img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
return img
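    # Worked example (illustrative numbers): a 1000x1500 (w x h) image with the default
    # img_size=448 exceeds the 448**2 pixel budget, so it is scaled by
    # sqrt(448**2 / (1000 * 1500)) ~= 0.366 to 365x548, and both sides are then rounded
    # up to a multiple of 56, giving a final size of 392x560.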
def __process_one__(
self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
) -> torch.Tensor:
"""Same operation as __process_one_with_processor__ but without going through numpy"""
video_or_img = self.resize_transform(video_or_img, img_size)
if isinstance(video_or_img, Image.Image):
video_or_img = pil_to_tensor(video_or_img)
assert isinstance(video_or_img, torch.Tensor)
if video_or_img.ndim == 3:
video_or_img = video_or_img[None]
assert video_or_img.ndim == 4 and video_or_img.shape[1] == 3, (
f"Invalid shape {video_or_img.shape}."
)
t, c, h, w = video_or_img.shape
p = self._patch_size
m = self._merge_size
        # Ensure 3 channels: repeat grayscale, drop the alpha channel from RGBA
if c == 1:
video_or_img = video_or_img.expand((-1, 3, -1, -1))
if c == 4:
video_or_img = video_or_img[:, :3]
c = video_or_img.shape[1]
        assert c == 3, "Expecting an RGB image in QwenImageProcessor"
# Reshape to t h w c' format
h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)
video_or_img = rearrange(
video_or_img,
"t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
**rearrange_dict,
)
assert video_or_img.shape[-1] == self._num_target_channels, (
f"{video_or_img.shape[-1]} != {self._num_target_channels}"
)
video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))
return video_or_img
@overload
def process_images(
self, image: Image.Image | torch.Tensor, img_size: int | None = None
) -> torch.Tensor: ...
@overload
def process_images(
self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
) -> list[torch.Tensor]: ...
def process_images(
self,
image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
img_size: int | None = None,
) -> torch.Tensor | list[torch.Tensor]:
if isinstance(image, list):
return [self.__process_one__(_x, img_size) for _x in image]
return self.__process_one__(image, img_size)
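# Illustrative sketch of the image preprocessing path ("example.jpg" is a hypothetical path):
#   image_processor = QwenImageProcessor(img_size=448)
#   patches = image_processor.process_images(Image.open("example.jpg"))
#   # For a single image, `patches` is a float tensor of shape (1, H // 14, W // 14, 588),
#   # where H and W are the resized dimensions and 588 = 3 * 14 * 14 flattened patch values.
#   # Normalization itself happens later, inside the model's image encoder.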
class ProcessorOutput(dict):
input_ids: torch.Tensor
attention_mask: torch.Tensor
image_embeds_insertion_points: list[torch.Tensor] | None
pixel_values: torch.Tensor | list[torch.Tensor] | None
    def to(
        self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
    ) -> "ProcessorOutput":
        pixel_values = self["pixel_values"]
        if isinstance(pixel_values, torch.Tensor):
            pixel_values = pixel_values.to(dtype).to(device)
        elif pixel_values is not None:
            pixel_values = [x.to(dtype).to(device) for x in pixel_values]
        return ProcessorOutput(
            {
                "input_ids": self["input_ids"].to(device),
                "attention_mask": self["attention_mask"].to(device),
                "image_embeds_insertion_points": self["image_embeds_insertion_points"],
                "pixel_values": pixel_values,
            }
        )
class BaseProcessor(ProcessorMixin):
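    """Base text + image chat processor. Subclasses are expected to provide the special
    token ids and an image processor (`_image_processor`); this class handles role
    templating, tokenization, interleaved image insertion points and padding."""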
def __init__(
self,
tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
pre_image_tokens: tuple[int, ...] = (),
post_image_tokens: tuple[int, ...] = (),
system_start_tokens: tuple[int, ...] = (),
system_end_tokens: tuple[int, ...] = (),
user_start_tokens: tuple[int, ...] = (),
user_end_tokens: tuple[int, ...] = (),
asst_start_tokens: tuple[int, ...] = (),
asst_end_tokens: tuple[int, ...] = (),
allow_system_prompt: bool = True,
pad_token: int = 0,
bos_token: int | None = None,
) -> None:
self.pre_image_tokens = list(pre_image_tokens)
self.post_image_tokens = list(post_image_tokens)
self.system_start_tokens = list(system_start_tokens)
self.system_end_tokens = list(system_end_tokens)
self.user_start_tokens = list(user_start_tokens)
self.user_end_tokens = list(user_end_tokens)
self.asst_start_tokens = list(asst_start_tokens)
self.asst_end_tokens = list(asst_end_tokens)
self._allow_system_prompt = allow_system_prompt
self.tokenizer = tokenizer
self._image_processor = None
self._pad_token = pad_token
self.bos_token = bos_token
@property
def image_processor(self) -> QwenImageProcessor:
assert self._image_processor is not None
return self._image_processor
def _process_content(
self,
message_content: MessageContent,
role: Literal["system", "user", "assistant"],
tokenized_messages: list[torch.Tensor],
insertion_points: list[int],
image_list: list[torch.Tensor | None],
token_count: int,
img_size: int | None = None,
**kwargs: Any,
) -> int:
mapping = {
"user": (self.user_start_tokens, self.user_end_tokens),
"assistant": (self.asst_start_tokens, self.asst_end_tokens),
"system": (self.system_start_tokens, self.system_end_tokens),
}
if role.lower() not in mapping:
raise ValueError(f"Unknown role '{role}' encountered in messages.")
start_tokens, end_tokens = mapping[role.lower()]
# 1) Add the start tokens
if start_tokens:
tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long))
token_count += len(start_tokens)
# 2) Process the message content one by one (potentially interleaved image and text)
for part in message_content:
elt_type = part["type"]
if elt_type == "image":
part = cast(ImageMessage, part)
self._process_image_message(
part,
tokenized_messages,
image_list,
img_size=img_size,
)
token_count += len(self.pre_image_tokens)
insertion_points.append(token_count)
token_count += len(self.post_image_tokens)
else:
part = cast(TextMessage, part)
self._process_text_message(
part["text"],
role=role,
token_list=tokenized_messages,
**kwargs,
)
token_count += tokenized_messages[-1].size(0)
# 3) Add the end tokens
if end_tokens:
tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long))
token_count += len(end_tokens)
return token_count
def _process_text_message(
self,
message: str,
role: Literal["system", "user", "assistant"],
token_list: list[torch.Tensor],
**kwargs: Any,
) -> None:
if role.lower() == "system" and not self._allow_system_prompt:
raise ValueError("System prompts are not allowed in this tokenizer configuration.")
tokens = self.tokenizer.encode(
message, add_special_tokens=False, return_tensors="pt", **kwargs
)
tokens = cast(torch.Tensor, tokens)
token_list.append(tokens.flatten().to(torch.long))
def _process_image_message(
self,
message: ImageMessage,
token_list: list[torch.Tensor],
image_list: list[torch.Tensor | None],
img_size: int | None = None,
) -> None:
img = message["image"]
if img is None:
image_list.append(None)
else:
image_list.append(
self.image_processor.process_images(
self._load_image(img), img_size=img_size
).squeeze(0)
)
if self.pre_image_tokens:
token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long))
if self.post_image_tokens:
token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long))
def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
if isinstance(image_path_or_image, str):
return Image.open(image_path_or_image).convert("RGB")
return image_path_or_image
def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
return torch.nn.functional.pad(
tokens,
(0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
value=pad_value,
)
def pad_tokenized_messages(
self,
tokenized_messages_batch: list[torch.Tensor],
image_insertion_points_batch: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
max_len = max(len(x) for x in tokenized_messages_batch)
if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
image_insertion_points_batch = [
x + max_len - len(tokenized_messages_batch[idx])
for idx, x in enumerate(image_insertion_points_batch)
]
input_ids = torch.stack(
[
self._maybe_pad(s, max_len - s.size(0), self._pad_token)
for s in tokenized_messages_batch
],
dim=0,
)
attention_mask = torch.stack(
[
self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
for s in tokenized_messages_batch
],
dim=0,
)
return input_ids, attention_mask, image_insertion_points_batch
def tokenize_messages(
self,
messages: ProcessorInput,
suppress_bos_token: bool = False,
**kwargs: Any,
) -> ProcessorOutput | None:
"""Tokenize a batch of messages into token IDs suitable for Helium1 CASA model.
Args:
messages (list[list[dict[str, str]]] | list[dict[str, str]]): Batch of message lists (or single list of messages),
where each message is a list of dictionaries with 'role' and 'content' keys.
continue_final_message (bool, optional): If True, the final message in each list will not have an end token added.
Defaults to False.
suppress_bos_token (bool, optional): If True, the beginning-of-sequence token will not be added.
Defaults to False.
**kwargs: Additional keyword arguments passed to the underlying encode method.
"""
if not messages:
return None
if isinstance(messages[0], dict):
messages = [messages] # type: ignore[assignment]
messages = cast(list[list[Message]], messages)
image_insertion_points_batch = []
tokenized_messages_batch = []
image_list: list[torch.Tensor | None] = []
for msgs in messages:
tokenized_messages = []
if not suppress_bos_token and self.bos_token is not None:
tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
insertion_points = []
token_count = 0
for msg in msgs:
token_count = self._process_content(
msg["content"],
role=msg["role"],
tokenized_messages=tokenized_messages,
insertion_points=insertion_points,
image_list=image_list,
token_count=token_count,
**kwargs,
)
tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))
if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
# Remove the assistant end tokens from the final message
end_token_len = len(self.asst_end_tokens)
tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
            if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
                # Append the assistant start tokens so the model generates the assistant reply
                tokenized_messages_batch[-1] = torch.cat(
                    [
                        tokenized_messages_batch[-1],
                        torch.Tensor(self.asst_start_tokens).to(torch.long),
                    ]
                )
input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
tokenized_messages_batch, image_insertion_points_batch
)
        # Collate pixel values: None if there are no (loaded) images, otherwise the list of
        # per-image patch tensors. Mixing loaded and missing images is not supported.
        pixel_values: None | torch.Tensor | list[torch.Tensor] = None
        if image_list:
            num_missing = sum(img is None for img in image_list)
            assert num_missing in (0, len(image_list)), (
                "Either all or none of the images must be None."
            )
            if num_missing == 0:
                pixel_values = cast(list[torch.Tensor], image_list)
return ProcessorOutput(
input_ids=input_ids,
image_embeds_insertion_points=image_embeds_insertion_points,
attention_mask=attention_mask,
pixel_values=pixel_values,
)
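# Illustrative usage sketch (assumes `processor` is a concrete processor built on top of
# BaseProcessor, e.g. by the model's custom code; "example.jpg" is a hypothetical path):
#   messages = [
#       {
#           "role": "user",
#           "content": [
#               {"type": "image", "image": "example.jpg"},
#               {"type": "text", "text": "Describe this image."},
#           ],
#       }
#   ]
#   batch = processor.tokenize_messages(messages)
#   batch = batch.to("cuda")  # also casts pixel_values to bfloat16 by default
#   # batch["input_ids"], batch["attention_mask"], batch["pixel_values"] and
#   # batch["image_embeds_insertion_points"] can then be passed to the model.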