|
from functools import partial |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from transformers.processing_utils import ProcessorMixin |
|
from transformers.image_processing_utils import BaseImageProcessor |
|
from transformers import AutoTokenizer, AutoConfig |
|
from transformers import BatchFeature |
|
|
|
from PIL import Image |
|
from torchvision.transforms import ( |
|
Compose, |
|
Normalize, |
|
Resize, |
|
ToTensor |
|
) |
|
|
|
|
|
# Per-channel (R, G, B) normalization statistics applied by ImageProcessor.
# NOTE(review): these numbers match the OpenAI CLIP statistics rather than the
# classic torchvision ImageNet ones (0.485, 0.456, 0.406) — the name may be
# misleading; confirm against the vision encoder this processor feeds.
IMAGENET_MEAN = (
    0.48145466,
    0.4578275,
    0.40821073,
)
IMAGENET_STD = (
    0.26862954,
    0.26130258,
    0.27577711,
)
|
|
|
|
|
def convert_to_rgb(x):
    """Return ``x`` converted to the 3-channel ``"RGB"`` mode.

    ``x`` is presumably a ``PIL.Image.Image``; any object exposing a
    compatible ``convert(mode)`` method is accepted.
    """
    target_mode = "RGB"
    return x.convert(target_mode)
|
|
|
|
|
def expand2square(image, background_color):
    """Pad ``image`` to a square, centering it on a ``background_color`` canvas.

    Square inputs are returned unchanged (the very same object). Otherwise a
    new image of side ``max(width, height)`` in the same mode is created and
    the original is pasted centered along its shorter axis; with an odd
    leftover, the extra pixel of padding ends up on the bottom/right.
    """
    w, h = image.size
    if w == h:
        return image

    side = max(w, h)
    canvas = Image.new(image.mode, (side, side), background_color)
    # Exactly one offset is non-zero: only the shorter axis needs centering.
    offset = ((side - w) // 2, (side - h) // 2)
    canvas.paste(image, offset)
    return canvas
|
|
|
|
|
class ImageProcessor(BaseImageProcessor):
    """Turn an image into a normalized ``image_size`` x ``image_size`` tensor.

    Pipeline: convert to RGB, pad to a square with a mean-colored background,
    resize, convert to a float tensor, then normalize with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.
    """

    def __init__(
        self,
        image_size: int,
        **kwargs
    ):
        """
        Args:
            image_size: target side length. Because the image is padded to a
                square first, resizing its shorter edge to ``image_size``
                yields an ``image_size`` x ``image_size`` output.
            **kwargs: forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.image_size = image_size
        # Pad with the normalization mean (scaled to 0-255) so the padded
        # border normalizes to approximately zero.
        background_color = tuple(int(255 * v) for v in IMAGENET_MEAN)
        self.transform = Compose(
            [
                convert_to_rgb,
                partial(expand2square, background_color=background_color),
                Resize(image_size),
                ToTensor(),
                Normalize(
                    mean=IMAGENET_MEAN,
                    std=IMAGENET_STD,
                ),
            ]
        )

    def preprocess(
        self,
        image: "Image.Image",
        **kwargs
    ):
        """Apply the transform pipeline to a single image and return a tensor.

        ``**kwargs`` is accepted and ignored: ``BaseImageProcessor.__call__``
        forwards its keyword arguments (e.g. ``return_tensors``) here, and
        rejecting them would raise ``TypeError``.
        """
        return self.transform(image)

    def __repr__(self):
        # Show the transform pipeline rather than the base-class JSON dump.
        return repr(self.transform)
|
|
|
|
|
class VLMProcessor(ProcessorMixin):
    """Pair an image preprocessor with a chat-templated tokenizer.

    Text prompts are wrapped in a fixed system/user chat with an ``<image>``
    placeholder and left-padded into a batch; images are run through
    :class:`ImageProcessor`.

    NOTE(review): ``ProcessorMixin.__init__`` is never called, so the mixin's
    save/load helpers are presumably unsupported — confirm before using them.
    """

    def __init__(self, config):
        """Build the processor from a model config.

        Args:
            config: model config; ``image_size``,
                ``text_decoder_name_or_path`` and ``num_image_latents`` are read.
        """
        self.config = config
        self.image_size = config.image_size

        self.feature_extractor = ImageProcessor(self.image_size)
        # "<image>" is registered as an additional special token so the prompt
        # placeholder is presumably kept as a single token id.
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.text_decoder_name_or_path, additional_special_tokens=["<image>"]
        )
        # ChatML-style template: "<|im_start|>{role}\n{content}<|im_end|>\n" per
        # message, plus an open assistant turn when add_generation_prompt is set.
        self.tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
        self.num_image_latents = config.num_image_latents

    def __call__(
        self, text=None, images=None, **kwargs
    ):
        """Tokenize prompts and/or preprocess images into a ``BatchFeature``.

        Args:
            text: a prompt string or a list of prompt strings. Each prompt is
                wrapped in the chat template after an ``<image>`` token.
            images: a single image or a list/tuple of images.
            **kwargs: only ``return_tensors`` is read, and only when
                ``images`` is given without ``text``.

        Returns:
            ``BatchFeature`` holding ``input_ids`` and ``attention_mask`` when
            ``text`` is given, and ``images`` when ``images`` is given.

        Raises:
            ValueError: if both ``text`` and ``images`` are ``None``, or if
                ``text`` is an empty list.
        """
        if text is None and images is None:
            raise ValueError("You have to specify either text or images.")

        encoding = None
        if text is not None:
            encoding = self._encode_text(text)

        if images is not None:
            image_features = self._encode_images(images)
            if encoding is None:
                return BatchFeature(
                    data={
                        "images": image_features,
                    },
                    tensor_type=kwargs.get("return_tensors"),
                )
            encoding["images"] = image_features

        return encoding

    def _encode_text(self, text):
        """Chat-template each prompt and left-pad the results into one batch."""
        if isinstance(text, str):
            text = [text]

        tokenized_texts = []
        for t in text:
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f" <image> {t}"},
            ]
            tokenized_prompt = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            )
            # A single conversation was templated, so keep row 0 only.
            tokenized_texts.append(tokenized_prompt[0])

        if not tokenized_texts:
            raise ValueError("`text` must contain at least one prompt.")

        batch_size = len(tokenized_texts)
        max_len = max(len(tokens) for tokens in tokenized_texts)

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Padded positions are masked out below, so any valid id will do;
            # fall back instead of letting torch.full fail on None.
            pad_token_id = self.tokenizer.eos_token_id
            if pad_token_id is None:
                pad_token_id = 0

        input_ids = torch.full(
            (batch_size, max_len),
            fill_value=pad_token_id,
            dtype=torch.int64,
        )
        attention_mask = torch.full(
            (batch_size, max_len), fill_value=0, dtype=torch.int64
        )

        # Left padding: each prompt is right-aligned in its row.
        for i, tokens in enumerate(tokenized_texts):
            input_ids[i, -len(tokens):] = tokens
            attention_mask[i, -len(tokens):] = 1

        # Append (num_image_latents - 1) attended positions on the right.
        # NOTE(review): presumably the model expands the single <image> token
        # into num_image_latents embeddings — confirm against the model.
        attention_mask = F.pad(
            attention_mask, pad=(0, self.num_image_latents - 1), value=1
        )

        return BatchFeature(
            data={"input_ids": input_ids, "attention_mask": attention_mask}
        )

    def _encode_images(self, images):
        """Transform one image, or a list/tuple of images, into a 4-D batch."""
        if isinstance(images, (list, tuple)):
            image_features = torch.empty(
                (len(images), 3, self.image_size, self.image_size),
                dtype=torch.float32,
            )
            for i, image in enumerate(images):
                image_features[i] = self.feature_extractor(image)
            return image_features

        return self.feature_extractor(images).unsqueeze(0)

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the wrapped tokenizer's ``batch_decode``.

        The tokenizer is the ``AutoTokenizer`` loaded from
        ``config.text_decoder_name_or_path``.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the wrapped tokenizer's ``decode``.

        The tokenizer is the ``AutoTokenizer`` loaded from
        ``config.text_decoder_name_or_path``.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        trust_remote_code=False,
        **kwargs
    ):
        """Load the model config with ``AutoConfig`` and build a processor.

        Extra ``**kwargs`` are accepted but currently ignored.
        """
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code
        )
        return cls(config)
|
|