Spaces:
Runtime error
Runtime error
| import torch | |
| from PIL import Image | |
| from typing import Tuple, List | |
| import numpy as np | |
| from transformers import GemmaTokenizerFast, BatchFeature | |
| import json | |
| import os | |
def preprocess_imgs(imgs: List[Image.Image],
                    img_size: Tuple[int, int],
                    rescale: float,
                    mean: Tuple[float, float, float],
                    std: Tuple[float, float, float],
                    resample: int = 3):
    """Turn a batch of PIL images into one normalized float32 tensor.

    Pipeline: resize -> rescale -> per-channel normalize -> HWC->CHW -> stack.

    Args:
        imgs: input images; assumed 3-channel RGB (transpose(2, 0, 1)
            requires a channel axis) — TODO confirm against callers.
        img_size: target size. NOTE(review): forwarded to PIL's ``resize``
            as-is, but the caller builds it as (height, width) while PIL
            expects (width, height); harmless for square sizes — verify
            before using non-square sizes.
        rescale: multiplicative factor applied to raw pixel values
            (e.g. 1/255).
        mean: per-channel mean subtracted after rescaling.
        std: per-channel std divided out after mean subtraction.
        resample: PIL resampling filter id. Defaults to 3 (bicubic),
            matching the previously hard-coded value, but can now be fed
            from ``preprocessor_config.json``'s ``resample`` entry.

    Returns:
        torch.FloatTensor of shape (batch, channels, height, width).
    """
    def _normalize(img, mean, std):
        # (img - mean) / std, broadcast over the trailing channel axis.
        return (img - np.array(mean, dtype=img.dtype)) / np.array(std, dtype=img.dtype)

    resized_imgs = [
        np.array(img.resize((img_size[0], img_size[1]), resample=resample))
        for img in imgs
    ]
    rescaled_imgs = [(img * rescale).astype(np.float32) for img in resized_imgs]
    normalized_imgs = [_normalize(img, mean, std) for img in rescaled_imgs]
    # HWC -> CHW, then stack into a single (B, C, H, W) batch tensor.
    transposed_imgs = [img.transpose(2, 0, 1) for img in normalized_imgs]
    tensor_imgs = torch.tensor(np.stack(transposed_imgs, axis=0), dtype=torch.float32)
    return tensor_imgs
def preprocess_prompts(prompt, image_token, max_num_image_token, bos_token):
    """Build a PaliGemma text prompt.

    Layout: ``max_num_image_token`` copies of ``image_token``, then the
    BOS token, then the prompt terminated by a newline.
    """
    placeholder_block = image_token * max_num_image_token
    return "".join((placeholder_block, bos_token, prompt, "\n"))
class PaliGemmaProcessor:
    """Joint image+text preprocessor for PaliGemma.

    Wraps a ``GemmaTokenizerFast`` (registering the ``<image>`` placeholder
    plus the detection/segmentation vocabulary) and converts a batch of
    (images, prompts) into model-ready inputs.
    """

    # Placeholder repeated `image_seq_length` times ahead of every prompt;
    # the model substitutes image embeddings at these positions.
    IMAGE_TOKEN = "<image>"

    def __init__(self,
                 tokenizer: GemmaTokenizerFast) -> None:
        """Register PaliGemma's special tokens on `tokenizer` and keep it.

        BOS/EOS auto-insertion is disabled because `preprocess_prompts`
        places the BOS token manually (and no EOS is appended).
        """
        additional_special_tokens = {"additional_special_tokens": [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(additional_special_tokens)
        EXTRA_TOKENS = [
            f"<loc{i:04d}>" for i in range(1024)
        ]  # These tokens are used for object detection (bounding boxes)
        EXTRA_TOKENS += [
            f"<seg{i:03d}>" for i in range(128)
        ]  # These tokens are used for segmentation masks
        tokenizer.add_tokens(EXTRA_TOKENS)
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False
        self.tokenizer = tokenizer

    def from_pretrained(self, pretrained_dir):
        """Load image-preprocessing settings from
        ``preprocessor_config.json`` under `pretrained_dir`.

        Returns self so calls can be chained (fluent style).
        """
        with open(os.path.join(pretrained_dir, "preprocessor_config.json"), "r") as f:
            config = json.loads(f.read())
        # Number of <image> placeholders prepended to each prompt.
        self.image_seq_length = config['image_seq_length']
        self.image_mean = config['image_mean']
        self.image_std = config['image_std']
        # NOTE(review): loaded but never forwarded to preprocess_imgs,
        # which hard-codes its resampling filter — confirm intended.
        self.resample = config['resample']
        self.rescale_factor = config['rescale_factor']
        self.size = (config['size']['height'], config['size']['width'])
        return self

    def __call__(self,
                 imgs: List[Image.Image],
                 prompts: List[str],
                 padding: str = "longest",
                 truncation: bool = True,
                 max_length: int = None):
        """Preprocess paired images and prompts.

        Returns the tokenizer's output dict (input_ids, attention_mask, …
        as torch tensors) with a ``pixel_values`` tensor added. `imgs` and
        `prompts` are assumed to be parallel lists — TODO confirm callers
        guarantee equal lengths.
        """
        processed_imgs = preprocess_imgs(imgs,
                                         img_size=self.size,
                                         rescale=self.rescale_factor,
                                         mean=self.image_mean,
                                         # BUG FIX: previously passed
                                         # self.image_mean as std, so images
                                         # were normalized with the wrong
                                         # statistics.
                                         std=self.image_std)
        processed_prompts = [preprocess_prompts(prompt,
                                                image_token=self.IMAGE_TOKEN,
                                                max_num_image_token=self.image_seq_length,
                                                bos_token=self.tokenizer.bos_token) for prompt in prompts]
        model_inputs = self.tokenizer(processed_prompts,
                                      return_tensors='pt',
                                      padding=padding,
                                      truncation=truncation,
                                      max_length=max_length)
        return {**model_inputs, "pixel_values": processed_imgs}