ChatRex-7B / preprocessing_chatrex.py
CRIS-Yang's picture
Model Initial Update 1
7b6241f verified
"""
Processor class for ChatRex.
"""
from typing import Optional
import PIL
from PIL import Image
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
import re
from typing import List, Optional, Union
import numpy as np
import torch
import torchvision.transforms.functional as F
from transformers import AutoTokenizer
from transformers.image_utils import ImageInput
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
TextKwargs)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
# Module-level logger obtained from ``transformers.utils.logging``.
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# Special ids and placeholder strings.
# ---------------------------------------------------------------------------

# Not referenced in this section; presumably the label id masked out of the
# loss during training (TODO confirm).
IGNORE_INDEX = -100
# Not referenced in this section; presumably the id used when padding batches.
DEFAULT_PAD_TOKEN_INDEX = 0

# Image placeholder: ``tokenizer_image_object_token`` replaces each literal
# "<image>" in a prompt with the negative sentinel id below.
DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200

# Object placeholders. "<i>" in DEFAULT_OBJECT_TOKEN is replaced by the box
# index (e.g. "<obj0>"); every "<objfeat>" is mapped to DEFAULT_OBJECT_INDEX
# by ``tokenizer_image_object_token``.
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

# Grounding markup strings (not referenced in this section).
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
def xyxy_to_xywh(boxes):
    """
    Convert boxes from xyxy (two corners) to xywh (corner + size) format.

    Parameters:
        boxes (array-like): An array of shape (N, 4) where N is the number of
            boxes. Each box is represented as [x_min, y_min, x_max, y_max].
            Extra columns beyond the fourth are ignored. An empty sequence or a
            single flat box of 4 values is also accepted.

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented
        as [x_min, y_min, w, h] with w = x_max - x_min and h = y_max - y_min.
        The dtype follows the input.
    """
    boxes = np.array(boxes)
    # np.array([]) is 1-D with shape (0,); promote flat input to (N, 4) so an
    # empty box list yields an empty (0, 4) result instead of an IndexError.
    if boxes.ndim == 1:
        boxes = boxes.reshape(-1, 4)
    top_left = boxes[:, 0:2]
    size = boxes[:, 2:4] - top_left
    return np.concatenate([top_left, size], axis=1)
def xywh_to_xyxy(boxes):
    """
    Convert boxes from xywh (corner + size) to xyxy (two corners) format.

    Parameters:
        boxes (array-like): An array of shape (N, 4) where N is the number of
            boxes. Each box is represented as [x, y, width, height]. Extra
            columns beyond the fourth are ignored. An empty sequence or a
            single flat box of 4 values is also accepted.

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented
        as [x_min, y_min, x_max, y_max] with x_max = x + width and
        y_max = y + height. The dtype follows the input.
    """
    boxes = np.array(boxes)
    # np.array([]) is 1-D with shape (0,); promote flat input to (N, 4) so an
    # empty box list yields an empty (0, 4) result instead of an IndexError.
    if boxes.ndim == 1:
        boxes = boxes.reshape(-1, 4)
    top_left = boxes[:, 0:2]
    bottom_right = top_left + boxes[:, 2:4]
    return np.concatenate([top_left, bottom_right], axis=1)
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` onto a square canvas, centred along its shorter side.

    The canvas side equals the longer of the image's width and height, and
    the uncovered area is filled with ``background_color``. An image that is
    already square is returned as-is (same object, no copy).

    Args:
        pil_img (PIL.Image.Image): The image to pad.
        background_color: Fill colour passed to ``Image.new`` for the image's
            mode.

    Returns:
        PIL.Image.Image: The square image.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    # Only the shorter dimension gets a non-zero offset; the longer one
    # already spans the whole canvas.
    offset = ((side - w) // 2, (side - h) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
def pad_boxes(gt_boxes, old_size):
    """Shift boxes to follow the centring offset applied by ``expand2square``.

    Only columns 0 (x) and 1 (y) are translated and columns 2-3 are left
    unchanged. The result is therefore correct for xywh boxes, where width
    and height are invariant under translation, but not for xyxy boxes.

    Args:
        gt_boxes (array-like): Boxes of shape (N, 4) as [x, y, w, h] on the
            un-padded image. An empty sequence or a single flat box is also
            accepted.
        old_size (tuple): ``(width, height)`` of the un-padded image.

    Returns:
        numpy.ndarray: float32 array of shape (N, 4) with x/y shifted by the
        left/top padding.
    """
    old_w, old_h = old_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)
    # An empty list becomes shape (0,); make it (0, 4) so column indexing works.
    if gt_boxes.ndim == 1:
        gt_boxes = gt_boxes.reshape(-1, 4)
    # Mirror expand2square: a wide image is padded top/bottom, a tall (or
    # square) one left/right, with the smaller half placed before the image.
    if old_w > old_h:
        pad_left, pad_top = 0, (old_w - old_h) // 2
    else:
        pad_left, pad_top = (old_h - old_w) // 2, 0
    gt_boxes[:, 0] += pad_left  # x
    gt_boxes[:, 1] += pad_top  # y
    return gt_boxes
def resize_boxes(gt_boxes, old_size, new_size):
    """Rescale xywh boxes from the old image size to the new image size.

    Note the differing argument orders: ``old_size`` is ``(width, height)``
    while ``new_size`` is ``(height, width)``; existing callers rely on this.

    Both axes are scaled relative to ``max(old_w, old_h)``, the old image's
    longer side (its side length when it has already been padded square).

    Args:
        gt_boxes (array-like): Boxes of shape (N, 4) as [x, y, w, h]. An empty
            sequence or a single flat box is also accepted.
        old_size (tuple): ``(width, height)`` of the source image.
        new_size (tuple): ``(height, width)`` of the target image.

    Returns:
        numpy.ndarray: float32 array of shape (N, 4) with x/w scaled by
        ``new_w / max(old_w, old_h)`` and y/h by ``new_h / max(old_w, old_h)``.
    """
    old_w, old_h = old_size
    new_h, new_w = new_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)
    # An empty list becomes shape (0,); make it (0, 4) so column indexing works.
    if gt_boxes.ndim == 1:
        gt_boxes = gt_boxes.reshape(-1, 4)
    ref_side = max(old_w, old_h)
    scale_x = new_w / ref_side
    scale_y = new_h / ref_side
    # Columns are [x, y, w, h]: x and w scale horizontally, y and h vertically.
    gt_boxes[:, 0:4] *= np.array([scale_x, scale_y, scale_x, scale_y], dtype=np.float32)
    return gt_boxes
def split_special_strings(input_string: str, special_strings: Optional[list[str]] = None):
    """Split the input string around special strings, keeping them as items.

    Args:
        input_string (str): The input string to split.
        special_strings (list[str], optional): Literal strings (not regex
            patterns) to split on. Each occurrence becomes its own list item.
            Empty strings are ignored. When ``None`` or empty, the input is
            returned unsplit.

    Example:
        split_special_strings(
            "<image>\\n<obj0><objfeat><obj1><objfeat>\\n I am happy today.",
            ["<image>", "<objfeat>"],
        )
        -> ['<image>', '\\n<obj0>', '<objfeat>', '<obj1>', '<objfeat>', '\\n I am happy today.']

    Returns:
        list: The non-empty pieces of ``input_string`` in their original order,
        with every occurrence of a special string as its own element. An empty
        ``input_string`` yields ``[]``.
    """
    # Drop empty strings (they would match everywhere). Try longer strings
    # first: regex alternation is leftmost-first, so a special string that is
    # a prefix of another (e.g. "<obj" vs "<objfeat>") must not shadow it.
    tokens = sorted((s for s in (special_strings or []) if s), key=len, reverse=True)
    if not tokens:
        return [input_string] if input_string else []
    pattern = "|".join(map(re.escape, tokens))
    # The capturing group makes re.split keep the matched special strings.
    pieces = re.split(f"({pattern})", input_string)
    # re.split yields empty strings between adjacent matches; drop them.
    return [piece for piece in pieces if piece]
def tokenizer_image_object_token(prompt, tokenizer):
    """Tokenize ``prompt``, mapping image/object placeholders to sentinel ids.

    Every "<image>" in the prompt becomes ``IMAGE_TOKEN_INDEX`` and every
    "<objfeat>" becomes ``DEFAULT_OBJECT_INDEX``. The remaining text is
    encoded with ``tokenizer.encode(..., add_special_tokens=False)``.

    Args:
        prompt (str): The fully formatted prompt.
        tokenizer: A tokenizer exposing ``bos_token_id`` and ``encode``.

    Returns:
        list[int]: Token ids, prefixed with ``tokenizer.bos_token_id`` when the
        tokenizer defines one.
    """
    placeholder_ids = {
        DEFAULT_IMAGE_TOKEN: IMAGE_TOKEN_INDEX,
        DEFAULT_OBJECT_FEATURE_TOKEN: DEFAULT_OBJECT_INDEX,
    }
    input_ids = []
    # Some tokenizers define no BOS token (bos_token_id is None). A None entry
    # would make the id list unusable downstream (e.g. torch.tensor fails).
    if tokenizer.bos_token_id is not None:
        input_ids.append(tokenizer.bos_token_id)
    for chunk in split_special_strings(prompt, list(placeholder_ids)):
        if chunk in placeholder_ids:
            input_ids.append(placeholder_ids[chunk])
        else:
            input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
    return input_ids
class ChatRexProcessor(ProcessorMixin):
    """Turn an image, bounding boxes and a question into ChatRex model inputs.

    Pairs an image processor and a tokenizer through ``ProcessorMixin``. The
    main entry point is :meth:`process`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Side length of the low-resolution copy returned under "pixel_values".
    _LOW_RES_SIZE = 336

    def __init__(self, image_processor=None, tokenizer: AutoTokenizer = None, **kwargs):
        # NOTE(review): extra **kwargs are accepted but not forwarded to
        # ProcessorMixin.__init__ - confirm this is intentional.
        super().__init__(image_processor, tokenizer)
        self._special_tokens = None
        # Conversation template; only "INSTRUCTION" is used by ``process``.
        self.template = dict(
            SYSTEM=('A chat between a curious user and an artificial '
                    'intelligence assistant. The assistant gives '
                    'helpful, detailed, and polite answers to the '
                    'user\'s questions. {system}\n '),
            INSTRUCTION=('USER: {input} ASSISTANT:'),
            SEP='\n')

    @staticmethod
    def _load_image(image):
        """Return ``image`` as a PIL image, opening it as RGB when given a path."""
        if isinstance(image, str):
            # The context manager closes the file handle; convert() returns a
            # new, fully loaded image that does not depend on the open file.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        return image

    def _build_prompt(self, question, num_boxes):
        """Build the single-turn prompt with one image and per-box placeholders."""
        # One "<obj{i}><objfeat>" pair per box, and nothing when there are no
        # boxes (joining on "<objfeat>" and appending one more would otherwise
        # leave a stray placeholder with no matching box).
        obj_tokens = "".join(
            DEFAULT_OBJECT_TOKEN.replace("<i>", str(i)) + DEFAULT_OBJECT_FEATURE_TOKEN
            for i in range(num_boxes)
        )
        # Drop any caller-supplied "<image>" so exactly one is placed first.
        question = question.replace(DEFAULT_IMAGE_TOKEN, "")
        question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question
        return self.template["INSTRUCTION"].format(input=question)

    def process(
        self,
        image: Union[str, Image.Image],
        bbox: List[List[int]],
        question: str,
    ):
        """Prepare input data for inference.

        Args:
            image (Union[str, Image.Image]): A PIL image, or a path to an image
                file that is opened and converted to RGB.
            bbox (List[List[int]]): Bounding boxes on the original image, each
                in [x_min, y_min, x_max, y_max] order. May be empty.
            question (str): The question to ask about the image. Any "<image>"
                placeholder in it is removed and a single one is prepended.

        Returns:
            dict: With the keys
                - "pixel_values_aux": the square-padded image after
                  ``image_processor.preprocess``, with a leading batch dim of 1.
                - "pixel_values": that tensor bilinearly resized to 336x336,
                  with a leading batch dim of 1.
                - "gt_boxes": float32 xyxy boxes in "pixel_values_aux" pixel
                  coordinates, shape (1, N, 4).
                - "input_ids": prompt token ids of shape (1, L), with image and
                  object placeholders replaced by negative sentinel ids.
        """
        data_dict = {}

        # step1: load the image, pad it square, and run the image processor
        image = self._load_image(image)
        ori_w, ori_h = F.get_image_size(image)
        image = expand2square(
            image,
            tuple(int(x * 255) for x in self.image_processor.image_mean),
        )
        pad_w, pad_h = F.get_image_size(image)
        image_aux = self.image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        resize_h, resize_w = image_aux.shape[-2:]
        data_dict["pixel_values_aux"] = image_aux.unsqueeze(0)
        # interpolate returns a new tensor, so image_aux is not modified; its
        # (1, C, H, W) output already carries the batch dimension.
        data_dict["pixel_values"] = torch.nn.functional.interpolate(
            image_aux[None],
            size=[self._LOW_RES_SIZE, self._LOW_RES_SIZE],
            mode="bilinear",
            align_corners=False,
        )

        # step2: map boxes from original-image xyxy to processed-image xyxy
        boxes = np.asarray(bbox)
        if boxes.ndim == 1:
            # An empty list gives shape (0,); make it (0, 4) for the helpers.
            boxes = boxes.reshape(-1, 4)
        boxes = xyxy_to_xywh(boxes)
        boxes = pad_boxes(boxes, (ori_w, ori_h))
        boxes = resize_boxes(boxes, (pad_w, pad_h), (resize_h, resize_w))
        data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(boxes)).unsqueeze(0)

        # step3/4: build the prompt and tokenize it
        prompt = self._build_prompt(question, len(boxes))
        input_ids = tokenizer_image_object_token(prompt, self.tokenizer)
        data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0)
        return data_dict


# Register with transformers' auto-class machinery via
# ProcessorMixin.register_for_auto_class (presumably so the processor can be
# loaded from a checkpoint that ships custom code - confirm).
ChatRexProcessor.register_for_auto_class()