from typing import List, Optional

import numpy as np
import torch
from PIL import Image
from transformers import SiglipImageProcessor

from src.utils.util import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    add_image_tokens_to_prompt,
    process_images,
)


class ImageCraftProcessor:

    IMAGE_TOKEN = "<image>"

    def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
        super().__init__()

        self.image_seq_length = num_image_tokens
        self.image_size = image_size

        # Tokenizer described here:
        # https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
        tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(tokens_to_add)
        EXTRA_TOKENS = [
            f"<loc{i:04d}>" for i in range(1024)
        ]  # These tokens are used for object detection (bounding boxes)
        EXTRA_TOKENS += [
            f"<seg{i:03d}>" for i in range(128)
        ]  # These tokens are used for object segmentation
        tokenizer.add_tokens(EXTRA_TOKENS)
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
        # We will add the BOS and EOS tokens ourselves
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        self.tokenizer = tokenizer

    def __call__(
        self,
        text: List[str],
        images: List[Image.Image],
        padding: str = "longest",
        truncation: bool = True,
    ) -> dict:
        assert (
            len(images) == 1 and len(text) == 1
        ), f"Received {len(images)} images for {len(text)} prompts."

        pixel_values = process_images(
            images,
            size=(self.image_size, self.image_size),
            resample=Image.Resampling.BICUBIC,
            rescale_factor=1 / 255.0,
            image_mean=IMAGENET_STANDARD_MEAN,
            image_std=IMAGENET_STANDARD_STD,
        )
        # Convert the list of numpy arrays to a single numpy array with shape
        # [Batch_Size, Channel, Height, Width]
        pixel_values = np.stack(pixel_values, axis=0)
        # Convert the numpy array to a PyTorch tensor
        pixel_values = torch.tensor(pixel_values, dtype=torch.float16)

        # Prepend one <image> placeholder token per image patch embedding,
        # followed by the BOS token and the text prompt
        input_strings = [
            add_image_tokens_to_prompt(
                prefix_prompt=prompt,
                bos_token=self.tokenizer.bos_token,
                image_seq_length=self.image_seq_length,
                image_token=self.IMAGE_TOKEN,
            )
            for prompt in text
        ]

        # max_length += self.image_seq_length
        # Returns input_ids and attention_mask as PyTorch tensors
        inputs = self.tokenizer(
            input_strings,
            return_tensors="pt",
            padding=padding,
            max_length=512,
            truncation=truncation,
        )

        return_data = {"pixel_values": pixel_values, **inputs}

        return return_data
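

# --- Usage sketch (illustrative, not part of the processor) ---
# A minimal example of driving the processor end to end. The checkpoint path,
# image file, prompt text, and the num_image_tokens/image_size values below
# are assumptions for illustration; substitute the values that match your
# vision tower and tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Hypothetical checkpoint path; any Gemma-style tokenizer works here.
    tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
    processor = ImageCraftProcessor(
        tokenizer, num_image_tokens=256, image_size=224
    )

    inputs = processor(
        text=["caption en"],
        images=[Image.open("example.jpg").convert("RGB")],
    )
    print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
    print(inputs["input_ids"].shape)  # [1, image_seq_length + prompt tokens]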