import math
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image, ImageOps
from transformers import BatchFeature, LlavaNextProcessor


def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
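
# A quick sketch of the helpers (factor=14 matches the ViT patch size used below;
# 100 / 14 is about 7.14):
#   round_by_factor(100, 14) -> 98   (round(7.14) = 7, then 7 * 14)
#   ceil_by_factor(100, 14)  -> 112  (ceil(7.14) = 8, then 8 * 14)
#   floor_by_factor(100, 14) -> 98   (floor(7.14) = 7, then 7 * 14)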
					
						
class ColGraniteVisionProcessor(LlavaNextProcessor):
    """
    Processor for ColGraniteVision, a ColPali-style late-interaction retrieval
    model built on top of LLaVA-NeXT.
    """

    visual_prompt_prefix: ClassVar[str] = "<|user|>\n<image>\nDescribe the image.\n"
    system_message: ClassVar[str] = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    query_prefix: ClassVar[str] = "Query: "
    query_start: ClassVar[str] = "<|user|>\n"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.factor = 14  # resized dimensions are snapped to multiples of the patch size
        self.min_size = 384  # target size for the short image side
        self.max_size = 384 * 2  # target/cap for the long image side
        self.suffix_len = 10  # number of query-augmentation tokens appended to each query
        self.patch_size = 14  # ViT patch size in pixels

    @property
    def query_augmentation_token(self) -> str:
        """
        Return the query augmentation token.

        Query augmentation buffers are used as reasoning buffers during inference.
        """
        return self.tokenizer.pad_token

    @staticmethod
    def smart_resize_helper(
        width: int,
        height: int,
        factor: int,
        min_size: int,
        max_size: int,
    ) -> Tuple[int, int]:
        """
        Returns the resized image dimensions such that:

        1. The smaller dimension is set to `min_size`.
        2. The larger dimension is scaled proportionally to maintain the aspect ratio.
        3. If the larger dimension exceeds `max_size`, it is clipped to `max_size`,
           and the smaller dimension is adjusted accordingly to maintain the aspect ratio.
        4. Both dimensions are then rounded down so they are divisible by `factor`.
        """
        if height < width:
            scale_factor = min_size / height
        else:
            scale_factor = min_size / width

        new_width = round(width * scale_factor)
        new_height = round(height * scale_factor)

        if max(new_width, new_height) > max_size:
            clip_factor = max_size / max(new_width, new_height)
            new_width = round(new_width * clip_factor)
            new_height = round(new_height * clip_factor)

        # Snap both dimensions onto the patch grid; flooring guarantees the
        # clipped dimension never grows past `max_size`.
        new_width = floor_by_factor(new_width, factor)
        new_height = floor_by_factor(new_height, factor)

        return new_width, new_height
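
    # Worked example (with the defaults set in __init__): a 1000x500 image gives
    # scale_factor = 384/500, so (new_width, new_height) = (768, 384); nothing
    # exceeds max_size = 768, and flooring to the 14 px grid returns (756, 378).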
					
						
    @staticmethod
    def pad_image_center(
        image: Image.Image,
        target_width: int,
        target_height: int,
        fill_color=(0, 0, 0),
    ) -> Image.Image:
        """
        Pads the given image so that its content is centered within the target dimensions.

        :param image: PIL Image to be padded.
        :param target_width: The desired width after padding.
        :param target_height: The desired height after padding.
        :param fill_color: Background color (default is black).
        :return: Padded image with centered content.
        """
        img_width, img_height = image.size

        # Split the leftover space evenly; any odd pixel goes to the right/bottom.
        pad_left = (target_width - img_width) // 2
        pad_top = (target_height - img_height) // 2
        pad_right = target_width - img_width - pad_left
        pad_bottom = target_height - img_height - pad_top

        padded_image = ImageOps.expand(image, (pad_left, pad_top, pad_right, pad_bottom), fill_color).convert("RGB")

        return padded_image
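
    # Usage sketch: centering a 300x200 image on a 384x384 canvas adds 42 px of
    # fill on the left and right and 92 px on the top and bottom:
    #   padded = ColGraniteVisionProcessor.pad_image_center(img, 384, 384)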
					
						
    def smart_resize(self, image: Image.Image) -> Image.Image:
        """
        Resize and convert the image to the required format.
        """
        image_size = image.size
        # smart_resize_helper returns (width, height), in that order.
        resized_width, resized_height = self.smart_resize_helper(
            width=image_size[0],
            height=image_size[1],
            factor=self.factor,
            min_size=self.min_size,
            max_size=self.max_size,
        )
        return image.convert("RGB").resize((resized_width, resized_height))
					
						
    def smart_resize_and_pad(self, image: Image.Image) -> Image.Image:
        """
        Resize and pad the image to the required format.
        """
        return self.resize_and_pad_centered_to_long_side(
            image=image,
            factor=self.factor,
            min_size=self.min_size,
            max_size=self.max_size,
            fill_color=0,
        )

    def resize_and_pad_centered_to_long_side(
        self,
        image: Image.Image,
        factor: int,
        min_size: int,
        max_size: int,
        fill_color=0,
    ) -> Image.Image:
        """
        Resizes an image such that:
        - The long side is set to `max_size`.
        - The short side is scaled proportionally, but never below `min_size`
          (which can stretch the aspect ratio for very elongated pages).

        Note that this variant performs no padding; `factor` and `fill_color`
        are accepted for interface compatibility but are currently unused.

        :param image: PIL Image
        :param factor: Factor to make dimensions divisible by (currently unused)
        :param min_size: Minimum allowed size for the short side
        :param max_size: Target size for the long side
        :param fill_color: Background padding color (currently unused)
        :return: Resized image
        """
        width, height = image.size

        # A sentinel of -1 disables resizing altogether.
        if min_size == -1 or max_size == -1:
            return image.convert("RGB")

        if width > height:
            scale_factor = max_size / width
            target_width = max_size
            # Never let the short side drop below min_size.
            max_scale_factor = max(min_size / height, scale_factor)
            target_height = round(height * max_scale_factor)
        else:
            scale_factor = max_size / height
            target_height = max_size
            max_scale_factor = max(min_size / width, scale_factor)
            target_width = round(width * max_scale_factor)

        resized_image = image.resize((target_width, target_height), Image.LANCZOS)
        final_image = resized_image.convert("RGB")

        return final_image
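
    # Worked example (with the defaults from __init__): a 1000x500 page scales
    # to 768x384, which already satisfies min_size. A very wide 1000x300 page
    # also gets width 768, but its height is stretched up to min_size = 384.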
					
						
    def resize_and_pad_centered(
        self,
        image: Image.Image,
        factor: int,
        min_size: int,
        max_size: int,
        fill_color=0,
    ) -> Image.Image:
        """
        Resizes and pads an image such that:
        - The short side is set to `min_size`.
        - The long side is scaled proportionally but clipped to `max_size`.
        - The image is centered within the final padded area.

        :param image: PIL Image
        :param factor: Factor to make dimensions divisible by (currently unused)
        :param min_size: Minimum size for the short side
        :param max_size: Maximum allowed size for the long side
        :param fill_color: Background padding color (default black)
        :return: Resized and padded image
        """
        width, height = image.size

        # A sentinel of -1 disables resizing altogether.
        if min_size == -1 or max_size == -1:
            return image.convert("RGB")

        if width < height:
            scale_factor = min_size / width
            target_width = min_size
            # Never let the long side exceed max_size.
            max_scale_factor = min(max_size / height, scale_factor)
            target_height = round(height * max_scale_factor)
        else:
            scale_factor = min_size / height
            target_height = min_size
            max_scale_factor = min(max_size / width, scale_factor)
            target_width = round(width * max_scale_factor)

        resized_image = image.resize((target_width, target_height), Image.LANCZOS)

        # The final canvas is min_size x max_size, oriented like the input image.
        if width < height:
            final_width, final_height = min_size, max_size
        else:
            final_width, final_height = max_size, min_size

        pad_left = (final_width - target_width) // 2
        pad_top = (final_height - target_height) // 2
        pad_right = final_width - target_width - pad_left
        pad_bottom = final_height - target_height - pad_top

        final_image = ImageOps.expand(
            resized_image, (pad_left, pad_top, pad_right, pad_bottom), fill_color
        ).convert("RGB")

        return final_image
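
    # Worked example (with the defaults from __init__): a 500x1000 portrait page
    # scales to 384x768, so no fill is needed; a square 500x500 page scales to
    # 384x384 and is centered on a 768x384 canvas with 192 px of fill on each side.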
					
						
    def format_data(self, question, image):
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": question,
                    },
                ],
            },
        ]

    def format_data_wo_role(self, question, image=None):
        # Same structure as `format_data`, but without the system message.
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": question,
                    },
                ],
            }
        ]

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Process images for ColGraniteVision.
        """
        texts_doc = [self.visual_prompt_prefix for _ in images]
        images = [self.smart_resize_and_pad(image) for image in images]

        batch_doc = self(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="longest",
        )
        return batch_doc
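
    # Usage sketch (hedged; "<model-id>" stands in for a real checkpoint path):
    #   processor = ColGraniteVisionProcessor.from_pretrained("<model-id>")
    #   batch = processor.process_images([Image.open("page.png")])
    #   # batch is a BatchFeature holding input_ids, attention_mask, pixel_values, ...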
					
						
    def process_queries(self, queries, max_length=2048, suffix=None):
        if suffix is None:
            suffix = self.query_augmentation_token * self.suffix_len

        processed = []
        for q in queries:
            q = self.query_start + self.query_prefix + q

            # Truncate by character count so the suffix always fits within max_length.
            if len(q) + len(suffix) > max_length:
                q = q[: max_length - len(suffix) - 1]
            q += suffix + "\n"
            processed.append(q)

        return self(
            text=processed,
            images=None,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=max_length,
        )
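
    # Example: with the defaults, the string handed to the tokenizer for the
    # query "What is shown?" is
    #   "<|user|>\nQuery: What is shown?" + pad_token * 10 + "\n"
    # where the trailing pad tokens serve as query-augmentation buffers.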
					
						
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int,
    ) -> Tuple[int, int]:
        # `image_size` is unused: the image processor resizes every image to a
        # fixed size, so the patch grid depends only on that size.
        n_patches_x = self.image_processor.size["width"] // patch_size
        n_patches_y = self.image_processor.size["height"] // patch_size

        return n_patches_x, n_patches_y

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        return batch_images.input_ids == self.image_token_id

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot-product score for the given single-vector query and passage embeddings.
        """
        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
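
    # Shape sketch: two query vectors and three passage vectors of dimension 128
    # are stacked into (2, 128) and (3, 128), and the einsum yields a (2, 3)
    # matrix of dot products.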
					
						
    @staticmethod
    def score_multi_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColGraniteVision, a passage
        is the image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            qs (`Union[torch.Tensor, List[torch.Tensor]]`): Query embeddings.
            ps (`Union[torch.Tensor, List[torch.Tensor]]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If not
                provided, the embeddings are scored on whatever device they already live on.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The
            score tensor is saved on the "cpu" device.
        """
        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            scores_batch = []
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                qs[i : i + batch_size], batch_first=True, padding_value=0
            ).to(device)
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j : j + batch_size], batch_first=True, padding_value=0
                ).to(device)
                # MaxSim: for each query token, keep its best-matching passage
                # token (max over s), then sum over the query tokens (n).
                scores_batch.append(
                    torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)
                )
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
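

# End-to-end usage sketch (hedged; "<model-id>" is a placeholder checkpoint, and
# the embedding variables are assumed to come from a ColGraniteVision model):
#   processor = ColGraniteVisionProcessor.from_pretrained("<model-id>")
#   batch_docs = processor.process_images([Image.open("page.png")])
#   batch_queries = processor.process_queries(["total revenue in 2023"])
#   # ...run the model on both batches to obtain multi-vector embeddings...
#   scores = processor.score(query_embeddings, page_embeddings)  # (n_queries, n_pages)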