from typing import Tuple, Union

import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.models.clip import CLIPProcessor

from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform


""" Jina CLIP processor implementation """


class JinaCLIPProcessor(CLIPProcessor):
    image_processor_class = 'JinaCLIPImageProcessor'
    tokenizer_class = 'CLIPTokenizer'
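
# Loading sketch (an assumption, not part of the original module): when a model
# repo registers these classes through `auto_map` in its configs, the processor
# is typically obtained via AutoProcessor with remote code enabled:
#
#     from transformers import AutoProcessor
#     processor = AutoProcessor.from_pretrained(
#         'jinaai/jina-clip-v1', trust_remote_code=True
#     )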


""" Jina CLIP image processor implementation """


class JinaCLIPImageProcessor(BaseImageProcessor):
    model_input_names = ['pixel_values']

    def __init__(
        self,
        size: Union[int, Tuple[int, int]] = 224,
        mean: Union[float, Tuple[float, ...]] = OPENAI_DATASET_MEAN,
        std: Union[float, Tuple[float, ...]] = OPENAI_DATASET_STD,
        resize_mode: str = 'shortest',
        interpolation: str = 'bicubic',
        fill_color: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.size = size
        self.mean = mean
        self.std = std
        self.resize_mode = resize_mode
        self.interpolation = interpolation
        self.fill_color = fill_color
        # Build the deterministic (eval-mode) transform once; `preprocess`
        # reuses it for every image.
        self.transform = image_transform(
            image_size=size,
            is_train=False,
            mean=mean,
            std=std,
            resize_mode=resize_mode,
            interpolation=interpolation,
            fill_color=fill_color,
            aug_cfg=None,
        )

    def to_dict(self):
        # Drop the constructed transform from the serialized config: it is not
        # JSON-serializable and is rebuilt from the other attributes on load.
        output = super().to_dict()
        output.pop('transform')
        return output

    def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
        # Accept a single image or a list of images; extra kwargs are accepted
        # for API compatibility but ignored.
        images = make_list_of_images(images)
        # Apply the shared eval transform to each image and stack the results
        # into a (batch, channels, height, width) tensor.
        out = torch.stack([self.transform(image) for image in images], dim=0)
        return BatchFeature(data={'pixel_values': out})
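

# Minimal usage sketch. Assumptions: this module is imported as part of its
# package (so the relative `.transform` import resolves) and Pillow is
# installed; the grey dummy image is hypothetical test data.
if __name__ == '__main__':
    from PIL import Image

    processor = JinaCLIPImageProcessor()
    dummy = Image.new('RGB', (640, 480), color=(127, 127, 127))
    batch = processor.preprocess(dummy)
    # With the default size=224, pixel_values has shape (1, 3, 224, 224).
    print(batch['pixel_values'].shape)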