# Source: jina-clip-implementation / processing_clip.py
# Last commit: 56fe6da by gmastrapas ("feat: initial commit")
# Size: 2.07 kB
# coding=utf-8
#
# Code mainly copied from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py
# and adjusted for Jina CLIP
from typing import Tuple, Union
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.models.clip import CLIPProcessor
from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform
""" Jina CLIP processor implementation """
class JinaCLIPProcessor(CLIPProcessor):
    """Processor that pairs a ``CLIPTokenizer`` with ``JinaCLIPImageProcessor``.

    Nothing beyond the component class names is overridden. All text and
    image handling is inherited unchanged from ``CLIPProcessor``.
    """

    tokenizer_class = 'CLIPTokenizer'
    image_processor_class = 'JinaCLIPImageProcessor'
""" Jina CLIP image processor implementation """
class JinaCLIPImageProcessor(BaseImageProcessor):
    """Image processor for Jina CLIP.

    Wraps the evaluation-mode transform produced by ``image_transform`` (from
    the sibling ``transform`` module) behind the Hugging Face
    ``BaseImageProcessor`` API. Every input image goes through that transform,
    and the results are stacked into one ``pixel_values`` tensor.
    """

    model_input_names = ['pixel_values']

    def __init__(
        self,
        size: Union[int, Tuple[int, int]] = 224,
        mean: Union[float, Tuple[float, ...]] = OPENAI_DATASET_MEAN,
        std: Union[float, Tuple[float, ...]] = OPENAI_DATASET_STD,
        resize_mode: str = 'shortest',
        interpolation: str = 'bicubic',
        fill_color: int = 0,
        **kwargs,
    ) -> None:
        """Store the preprocessing configuration and build the transform.

        Args:
            size: Target image size, forwarded to ``image_transform`` as
                ``image_size``.
            mean: Normalization mean (scalar or one value per channel),
                forwarded to ``image_transform``.
            std: Normalization std (scalar or one value per channel),
                forwarded to ``image_transform``.
            resize_mode: Resize strategy name forwarded to ``image_transform``.
            interpolation: Interpolation method name forwarded to
                ``image_transform``.
            fill_color: Fill value forwarded to ``image_transform``.
                Presumably used when padding; TODO confirm in ``transform``.
            **kwargs: Passed through to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        # Keep the configuration as plain instance attributes. ``to_dict``
        # starts from the base-class dict, which includes instance attributes.
        self.size = size
        self.mean = mean
        self.std = std
        self.resize_mode = resize_mode
        self.interpolation = interpolation
        self.fill_color = fill_color
        # Built once from the settings above (always in evaluation mode).
        # NOTE(review): reassigning those attributes later does not rebuild it.
        self.transform = image_transform(
            image_size=size,
            is_train=False,
            mean=mean,
            std=std,
            resize_mode=resize_mode,
            interpolation=interpolation,
            fill_color=fill_color,
            aug_cfg=None,
        )

    def to_dict(self):
        """Return the processor configuration as a dictionary.

        The base implementation includes the ``transform`` attribute. That
        attribute is a callable derived from the other settings, not
        configuration in its own right, so it is removed here.
        """
        output = super().to_dict()
        # Use a default so a missing key does not raise KeyError.
        output.pop('transform', None)
        return output

    def preprocess(
        self,
        images: ImageInput,
        *,
        return_tensors: Union[str, None] = None,
        **kwargs,
    ) -> BatchFeature:
        """Transform one or more images into a batched ``pixel_values`` entry.

        Args:
            images: A single image or a batch of images. It is normalized to a
                list with ``make_list_of_images``.
            return_tensors: Optional tensor type forwarded to ``BatchFeature``
                as ``tensor_type``. With the default ``None`` the stacked
                ``torch.Tensor`` is returned without conversion.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            A ``BatchFeature`` whose ``'pixel_values'`` entry stacks the
            transformed images along a new leading dimension (dim 0).
        """
        images = make_list_of_images(images)
        pixel_values = torch.stack(
            [self.transform(image) for image in images], dim=0
        )
        # Honor the requested tensor type instead of silently swallowing it
        # in **kwargs, so the pixel values can match the tokenizer output.
        return BatchFeature(
            data={'pixel_values': pixel_values}, tensor_type=return_tensors
        )