# coding=utf-8
#
# Code mainly copied from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py
# and adjusted for Jina CLIP

from typing import Tuple, Union

import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.models.clip import CLIPProcessor

from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform

""" Jina CLIP processor implementation """


class JinaCLIPProcessor(CLIPProcessor):
    # Pairs the Jina CLIP image processor below with the stock CLIP tokenizer.
    image_processor_class = 'JinaCLIPImageProcessor'
    tokenizer_class = 'CLIPTokenizer'


""" Jina CLIP image processor implementation """


class JinaCLIPImageProcessor(BaseImageProcessor):
    model_input_names = ['pixel_values']

    def __init__(
        self,
        size: Union[int, Tuple[int, int]] = 224,
        mean: Union[float, Tuple[float]] = OPENAI_DATASET_MEAN,
        std: Union[float, Tuple[float]] = OPENAI_DATASET_STD,
        resize_mode: str = 'shortest',
        interpolation: str = 'bicubic',
        fill_color: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.size = size
        self.mean = mean
        self.std = std
        self.resize_mode = resize_mode
        self.interpolation = interpolation
        self.fill_color = fill_color
        # Deterministic eval-time transform (resize/crop or pad, to-tensor,
        # normalize), rebuilt from the attributes above on every construction.
        self.transform = image_transform(
            image_size=size,
            is_train=False,
            mean=mean,
            std=std,
            resize_mode=resize_mode,
            interpolation=interpolation,
            fill_color=fill_color,
            aug_cfg=None,
        )

    def to_dict(self):
        output = super().to_dict()
        # The transform is a callable, not JSON-serializable state; drop it
        # from the serialized config (it is reconstructed in __init__).
        output.pop('transform')
        return output

    def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
        # Accept a single image or a list of images, apply the transform to
        # each, and stack the per-image tensors into one (B, C, H, W) batch.
        images = make_list_of_images(images)
        out = torch.stack([self.transform(image) for image in images], dim=0)
        return BatchFeature(data={'pixel_values': out})
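
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the shipped module):
# builds the image processor with its defaults and pushes one synthetic PIL
# image through `preprocess`. Because of the relative import of `.transform`
# above, this file has to run inside its package (e.g. via `python -m ...`),
# not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    from PIL import Image

    processor = JinaCLIPImageProcessor()
    # Synthetic 336x448 RGB image standing in for real input.
    dummy = Image.fromarray(
        np.random.randint(0, 256, size=(336, 448, 3), dtype=np.uint8)
    )
    batch = processor.preprocess(dummy)
    # A single image becomes a batch of one: (1, 3, 224, 224) with defaults.
    print(batch['pixel_values'].shape)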