"""Text and Image processor for CASA models using Qwen2.5_VL image encoder""" |
|
|
|
|
|
from math import ceil |
|
|
from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload |
|
|
from typing import cast as type_cast |
|
|
|
|
|
import torch |
|
|
import torchvision.transforms.v2 as T |
|
|
from einops import rearrange |
|
|
from PIL import Image |
|
|
from torchvision.transforms import InterpolationMode |
|
|
from torchvision.transforms.functional import to_tensor as pil_to_tensor |
|
|
from torchvision.transforms.v2 import functional as F |
|
|
from transformers.image_processing_utils import BaseImageProcessor |
|
|
from transformers.processing_utils import ProcessorMixin |
|
|
|
|
|
if TYPE_CHECKING:
    from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast


ImageMessage = TypedDict(
    "ImageMessage",
    {
        "type": Literal["image"],
        "image": str | Image.Image | None,
    },
)

TextMessage = TypedDict(
    "TextMessage",
    {
        "type": Literal["text"],
        "text": str,
    },
)

MessageContent = list[ImageMessage | TextMessage]

Message = TypedDict(
    "Message",
    {
        "role": Literal["system", "user", "assistant"],
        "content": MessageContent,
    },
)

ProcessorInput = list[list[Message]] | list[Message]

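# Illustrative shape of a ProcessorInput (the values below are made up for the example):
#   [
#       {
#           "role": "user",
#           "content": [
#               {"type": "image", "image": "path/to/photo.jpg"},
#               {"type": "text", "text": "Describe this image."},
#           ],
#       },
#   ]
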
__INTERP_NAME_TO_MODE__ = {
    "nearest": InterpolationMode.NEAREST,
    # "nearest_exact" is accepted by QwenImageProcessor's `interpolation` argument.
    "nearest_exact": InterpolationMode.NEAREST_EXACT,
    "bilinear": InterpolationMode.BILINEAR,
    "bicubic": InterpolationMode.BICUBIC,
    "lanczos": InterpolationMode.LANCZOS,
}

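# The integer keys mirror PIL's resampling codes (PIL.Image.Resampling):
# NEAREST=0, LANCZOS=1, BILINEAR=2, BICUBIC=3, BOX=4, HAMMING=5.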
__INTERP_INT_TO_MODE__ = {
    0: InterpolationMode.NEAREST,
    1: InterpolationMode.LANCZOS,
    2: InterpolationMode.BILINEAR,
    3: InterpolationMode.BICUBIC,
    4: InterpolationMode.BOX,
    5: InterpolationMode.HAMMING,
}


@overload
def universal_resize(
    img: Image.Image,
    size: tuple[int, int],
    interpolation: str | InterpolationMode | int = "bilinear",
    antialias: bool = True,
) -> Image.Image: ...
@overload
def universal_resize(
    img: torch.Tensor,
    size: tuple[int, int],
    interpolation: str | InterpolationMode | int = "bilinear",
    antialias: bool = True,
) -> torch.Tensor: ...
def universal_resize(
    img: Image.Image | torch.Tensor,
    size: tuple[int, int],
    interpolation: str | InterpolationMode | int = "bilinear",
    antialias: bool = True,
) -> Image.Image | torch.Tensor:
    """Resize that works for PIL.Image, CHW tensor, or BCHW tensor"""
    if isinstance(interpolation, str):
        interpolation = __INTERP_NAME_TO_MODE__[interpolation]
    elif isinstance(interpolation, int):
        interpolation = __INTERP_INT_TO_MODE__[interpolation]

    return F.resize(
        img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias
    )

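# Note: torchvision's v2 functional resize preserves the input type (PIL in, PIL out;
# tensor in, tensor out), which is what the overloads above promise.
# e.g. universal_resize(img, (448, 448), "bicubic")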

@overload
def convert_to_rgb(img: Image.Image) -> Image.Image: ...
@overload
def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
    """Convert any image to RGB in a way that does not trigger a PIL warning"""
    if isinstance(img, torch.Tensor):
        return img
    if img.mode == "RGB":
        return img
    if img.mode == "P":
        # Palette images go through RGBA first: converting "P" with transparency
        # straight to RGB triggers a PIL warning.
        return img.convert("RGBA").convert("RGB")
    return img.convert("RGB")


class QwenImageProcessor(BaseImageProcessor):
    """Resizing for the Qwen2.5-VL encoder. Note that normalization is handled
    by the image encoder in the model's forward pass."""

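    # Resize policy: clamp extreme aspect ratios, cap the total area at img_size**2,
    # then round each side up to a multiple of round_to_patch_size (56 px by default,
    # i.e. 4 patches of 14 px) so the result splits evenly into 14x14 patches and
    # 2x2 merge windows.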
    def __init__(
        self,
        img_size: int = 448,
        interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
        max_ratio: int = 10,
        round_to_patch_size: int = 56,
        use_fast: bool = True,
        **kwargs: Any,
    ) -> None:
        self._num_target_channels = 588  # 3 * 14 * 14: one flattened RGB patch
        self._merge_size = 2
        self._patch_size = 14
        super().__init__(
            use_fast=use_fast,
            do_normalize=False,
            **kwargs,
        )
        self.img_size = img_size
        self.interpolation = interpolation
        self.max_ratio = max_ratio
        self.round_to_patch_size = round_to_patch_size

    def resize_transform(
        self, img: Image.Image | torch.Tensor, img_size: int | None = None
    ) -> Image.Image | torch.Tensor:
        if img_size is None:
            img_size = self.img_size
        max_area = img_size**2
        if isinstance(img, Image.Image):
            img = convert_to_rgb(img)
            w_og, h_og = img.size
        else:
            h_og, w_og = img.shape[-2:]
        w, h = w_og, h_og

        # Clamp extreme aspect ratios: neither side may end up more than
        # max_ratio times smaller than the other.
        if self.max_ratio > 0:
            w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)

        # Downscale (keeping the aspect ratio) so the area does not exceed img_size**2.
        current_area = w * h
        if current_area > max_area:
            scale = (max_area / current_area) ** 0.5
            w, h = int(w * scale), int(h * scale)

        # Round each side up to a multiple of round_to_patch_size.
        if self.round_to_patch_size > 0:
            w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
            h = ceil(h / self.round_to_patch_size) * self.round_to_patch_size

        if w != w_og or h != h_og:
            img = universal_resize(img, (h, w), self.interpolation)
        if isinstance(img, torch.Tensor):
            img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
        return img

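    # Example of the resize policy (sketch, default settings): a 1000x800 input exceeds
    # 448**2 pixels, so it is scaled by ~0.5 to 500x400 and rounded up to 504x448,
    # i.e. a 36x32 grid of 14x14 patches.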
    def __process_one__(
        self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
    ) -> torch.Tensor:
        """Same operation as __process_one_with_processor__ but without going through numpy"""
        video_or_img = self.resize_transform(video_or_img, img_size)
        if isinstance(video_or_img, Image.Image):
            video_or_img = pil_to_tensor(video_or_img)
        assert isinstance(video_or_img, torch.Tensor)
        if video_or_img.ndim == 3:
            video_or_img = video_or_img[None]
        # Grayscale (1-channel) and RGBA (4-channel) inputs are converted to RGB below.
        assert video_or_img.ndim == 4 and video_or_img.shape[1] in (1, 3, 4), (
            f"Invalid shape {video_or_img.shape}."
        )
        t, c, h, w = video_or_img.shape
        p = self._patch_size
        m = self._merge_size

        if c == 1:
            video_or_img = video_or_img.expand((-1, 3, -1, -1))
        if c == 4:
            video_or_img = video_or_img[:, :3]
        c = video_or_img.shape[1]
        assert c == 3, "Expecting RGB image in QwenNormalize"

        h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
        rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)

        # Flatten every 14x14 RGB patch into a 588-dim row; rows are grouped by 2x2
        # merge window, the layout used for the Qwen2.5-VL-style patch merge.
        video_or_img = rearrange(
            video_or_img,
            "t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
            **rearrange_dict,
        )
        assert video_or_img.shape[-1] == self._num_target_channels, (
            f"{video_or_img.shape[-1]} != {self._num_target_channels}"
        )
        video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))

        return video_or_img

    @overload
    def process_images(
        self, image: Image.Image | torch.Tensor, img_size: int | None = None
    ) -> torch.Tensor: ...
    @overload
    def process_images(
        self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
    ) -> list[torch.Tensor]: ...
    def process_images(
        self,
        image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
        img_size: int | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        if isinstance(image, list):
            return [self.__process_one__(_x, img_size) for _x in image]
        return self.__process_one__(image, img_size)

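# Example usage of QwenImageProcessor (sketch; shapes assume the default settings):
#   processor = QwenImageProcessor()
#   patches = processor.process_images(Image.new("RGB", (448, 448)))
#   # patches.shape == (1, 32, 32, 588): a 32x32 grid of flattened 14x14 RGB patches.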

class ProcessorOutput(dict):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    image_embeds_insertion_points: list[torch.Tensor] | None
    pixel_values: torch.Tensor | list[torch.Tensor] | None

    def to(
        self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
    ) -> "ProcessorOutput":
        return ProcessorOutput(
            {
                "input_ids": self["input_ids"].to(device),
                "attention_mask": self["attention_mask"].to(device),
                "image_embeds_insertion_points": self["image_embeds_insertion_points"],
                "pixel_values": (
                    self["pixel_values"].to(dtype).to(device)
                    if isinstance(self["pixel_values"], torch.Tensor)
                    else [x.to(dtype).to(device) for x in self["pixel_values"]]
                    if self["pixel_values"] is not None
                    else None
                ),
            }
        )


class BaseProcessor(ProcessorMixin):
    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
        pre_image_tokens: tuple[int, ...] = (),
        post_image_tokens: tuple[int, ...] = (),
        system_start_tokens: tuple[int, ...] = (),
        system_end_tokens: tuple[int, ...] = (),
        user_start_tokens: tuple[int, ...] = (),
        user_end_tokens: tuple[int, ...] = (),
        asst_start_tokens: tuple[int, ...] = (),
        asst_end_tokens: tuple[int, ...] = (),
        allow_system_prompt: bool = True,
        pad_token: int = 0,
        bos_token: int | None = None,
    ) -> None:
        self.pre_image_tokens = list(pre_image_tokens)
        self.post_image_tokens = list(post_image_tokens)
        self.system_start_tokens = list(system_start_tokens)
        self.system_end_tokens = list(system_end_tokens)
        self.user_start_tokens = list(user_start_tokens)
        self.user_end_tokens = list(user_end_tokens)
        self.asst_start_tokens = list(asst_start_tokens)
        self.asst_end_tokens = list(asst_end_tokens)
        self._allow_system_prompt = allow_system_prompt
        self.tokenizer = tokenizer
        # Subclasses are expected to set this before image messages are processed.
        self._image_processor: QwenImageProcessor | None = None
        self._pad_token = pad_token
        self.bos_token = bos_token

    @property
    def image_processor(self) -> QwenImageProcessor:
        assert self._image_processor is not None
        return self._image_processor

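    # Each turn is framed as [role start tokens] + content + [role end tokens], with
    # token_count tracking the running sequence length so that the recorded image
    # insertion indices line up with positions in the final input_ids.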
    def _process_content(
        self,
        message_content: MessageContent,
        role: Literal["system", "user", "assistant"],
        tokenized_messages: list[torch.Tensor],
        insertion_points: list[int],
        image_list: list[torch.Tensor | None],
        token_count: int,
        img_size: int | None = None,
        **kwargs: Any,
    ) -> int:
        mapping = {
            "user": (self.user_start_tokens, self.user_end_tokens),
            "assistant": (self.asst_start_tokens, self.asst_end_tokens),
            "system": (self.system_start_tokens, self.system_end_tokens),
        }
        if role.lower() not in mapping:
            raise ValueError(f"Unknown role '{role}' encountered in messages.")
        start_tokens, end_tokens = mapping[role.lower()]

        if start_tokens:
            tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long))
            token_count += len(start_tokens)

        for part in message_content:
            elt_type = part["type"]
            if elt_type == "image":
                part = cast(ImageMessage, part)
                self._process_image_message(
                    part,
                    tokenized_messages,
                    image_list,
                    img_size=img_size,
                )
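                # The insertion index sits just past the pre-image tokens: this is where
                # the image embeddings are spliced in, before the post-image tokens.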
                token_count += len(self.pre_image_tokens)
                insertion_points.append(token_count)
                token_count += len(self.post_image_tokens)
            else:
                part = cast(TextMessage, part)
                self._process_text_message(
                    part["text"],
                    role=role,
                    token_list=tokenized_messages,
                    **kwargs,
                )
                token_count += tokenized_messages[-1].size(0)

        if end_tokens:
            tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long))
            token_count += len(end_tokens)
        return token_count

    def _process_text_message(
        self,
        message: str,
        role: Literal["system", "user", "assistant"],
        token_list: list[torch.Tensor],
        **kwargs: Any,
    ) -> None:
        if role.lower() == "system" and not self._allow_system_prompt:
            raise ValueError("System prompts are not allowed in this tokenizer configuration.")
        tokens = self.tokenizer.encode(
            message, add_special_tokens=False, return_tensors="pt", **kwargs
        )
        tokens = cast(torch.Tensor, tokens)
        token_list.append(tokens.flatten().to(torch.long))

    def _process_image_message(
        self,
        message: ImageMessage,
        token_list: list[torch.Tensor],
        image_list: list[torch.Tensor | None],
        img_size: int | None = None,
    ) -> None:
        img = message["image"]
        if img is None:
            image_list.append(None)
        else:
            image_list.append(
                self.image_processor.process_images(
                    self._load_image(img), img_size=img_size
                ).squeeze(0)
            )
        if self.pre_image_tokens:
            token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long))

        if self.post_image_tokens:
            token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long))

    def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
        if isinstance(image_path_or_image, str):
            return Image.open(image_path_or_image).convert("RGB")
        return image_path_or_image

    def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
        return torch.nn.functional.pad(
            tokens,
            (0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
            value=pad_value,
        )

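    # With left padding, each sequence is shifted right by its padding amount, so the
    # image insertion indices recorded earlier must be shifted by the same amount below.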
    def pad_tokenized_messages(
        self,
        tokenized_messages_batch: list[torch.Tensor],
        image_insertion_points_batch: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
        max_len = max(len(x) for x in tokenized_messages_batch)
        if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
            image_insertion_points_batch = [
                x + max_len - len(tokenized_messages_batch[idx])
                for idx, x in enumerate(image_insertion_points_batch)
            ]
        input_ids = torch.stack(
            [
                self._maybe_pad(s, max_len - s.size(0), self._pad_token)
                for s in tokenized_messages_batch
            ],
            dim=0,
        )
        attention_mask = torch.stack(
            [
                self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
                for s in tokenized_messages_batch
            ],
            dim=0,
        )
        return input_ids, attention_mask, image_insertion_points_batch

    def tokenize_messages(
        self,
        messages: ProcessorInput,
        suppress_bos_token: bool = False,
        **kwargs: Any,
    ) -> ProcessorOutput | None:
        """Tokenize a batch of messages into token IDs suitable for the Helium1 CASA model.

        Args:
            messages (list[list[Message]] | list[Message]): Batch of message lists (or a
                single list of messages), where each message is a dict with 'role' and
                'content' keys.
            suppress_bos_token (bool, optional): If True, the beginning-of-sequence token
                is not added. Defaults to False.
            **kwargs: Additional keyword arguments passed to the underlying encode method.

        If the final message comes from the assistant, its end tokens are dropped so that
        generation can continue it; if the final message comes from the user, the assistant
        start tokens are appended to prompt a reply.
        """
        if not messages:
            return None
        if isinstance(messages[0], dict):
            messages = [messages]

        messages = cast(list[list[Message]], messages)
        image_insertion_points_batch = []
        tokenized_messages_batch = []
        image_list: list[torch.Tensor | None] = []
        for msgs in messages:
            tokenized_messages = []
            insertion_points = []
            token_count = 0
            if not suppress_bos_token and self.bos_token is not None:
                tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
                # The BOS token occupies position 0, so count it before computing
                # image insertion indices.
                token_count += 1
            for msg in msgs:
                token_count = self._process_content(
                    msg["content"],
                    role=msg["role"],
                    tokenized_messages=tokenized_messages,
                    insertion_points=insertion_points,
                    image_list=image_list,
                    token_count=token_count,
                    **kwargs,
                )
            tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
            image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))

            if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
                # Let the model continue the final assistant message: drop its end tokens.
                end_token_len = len(self.asst_end_tokens)
                tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
            if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
                # Prompt a reply: append the assistant start tokens after the user turn.
                tokenized_messages_batch[-1] = torch.cat(
                    [
                        tokenized_messages_batch[-1],
                        torch.Tensor(self.asst_start_tokens).to(torch.long),
                    ]
                )

        input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
            tokenized_messages_batch, image_insertion_points_batch
        )

        if image_list:
            assert sum(img is None for img in image_list) % len(image_list) == 0, (
                "Either all or no image must be None."
            )
        pixel_values: None | torch.Tensor | list[torch.Tensor]
        if not image_list or image_list[0] is None:
            # Text-only input, or image placeholders without pixel data.
            pixel_values = None
        else:
            pixel_values = cast(list[torch.Tensor], image_list)
        return ProcessorOutput(
            input_ids=input_ids,
            image_embeds_insertion_points=image_embeds_insertion_points,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )
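
# Example usage of a BaseProcessor subclass (sketch; the special-token configuration and
# the resulting IDs depend entirely on the tokenizer, so treat this as illustrative only):
#   out = processor.tokenize_messages(
#       [{"role": "user", "content": [
#           {"type": "image", "image": "photo.jpg"},
#           {"type": "text", "text": "What is shown here?"},
#       ]}]
#   )
#   # out["input_ids"]: (batch, seq_len) token IDs
#   # out["image_embeds_insertion_points"]: one index tensor per sample, marking where the
#   #   image embeddings are spliced in between the pre- and post-image tokens
#   # out["pixel_values"]: one (H/14, W/14, 588) patch tensor per image, or None
#   batch = out.to("cuda", dtype=torch.bfloat16)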