from transformers.image_processing_utils import ImageProcessingMixin, BatchFeature
from torchvision.transforms import transforms as tf
import torchvision.transforms.functional as F
from PIL import Image
import torch


class CondViTProcessor(ImageProcessingMixin):
    """Image processor that square-pads, resizes and normalizes PIL images.

    Each image is converted to RGB if needed, padded to a square with a
    constant background color, resized to ``input_resolution``, converted to
    a tensor and normalized with ``image_mean`` / ``image_std``. Texts are
    passed through untouched.
    """

    def __init__(
        self,
        bkg_color=255,
        input_resolution=224,
        image_mean=(0.48145466, 0.4578275, 0.40821073),
        image_std=(0.26862954, 0.26130258, 0.27577711),
        **kwargs,
    ):
        """
        Parameters
        ----------
        bkg_color : int
            Fill value used for the constant padding added by ``square_pad``.
        input_resolution : int
            Size passed to ``F.resize`` after the image has been made square.
        image_mean : Sequence[float]
            Per-channel mean used for normalization (three values, one per
            RGB channel).
        image_std : Sequence[float]
            Per-channel standard deviation used for normalization.
        **kwargs
            Forwarded to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)

        self.bkg_color = bkg_color
        self.input_resolution = input_resolution
        self.image_mean = image_mean
        self.image_std = image_std

    def square_pad(self, image):
        """Pad a PIL image to a square whose side is its longest dimension.

        Parameters
        ----------
        image : Image.Image
            Image to pad; ``image.size`` is read as ``(width, height)``.

        Returns
        -------
        Image.Image
            The image centered on a ``bkg_color`` background. When the
            missing margin is odd, the extra pixel goes to the right/bottom.
        """
        width, height = image.size
        side = max(width, height)

        # Floor half of the missing margin on the leading edges...
        p_left = (side - width) // 2
        p_top = (side - height) // 2
        # ...and whatever remains on the trailing edges, so the total padded
        # size is exactly `side` even when the margin is odd.
        p_right = side - (width + p_left)
        p_bottom = side - (height + p_top)

        # torchvision expects 4-element padding as (left, top, right, bottom).
        padding = (p_left, p_top, p_right, p_bottom)
        return F.pad(image, padding, self.bkg_color, "constant")

    def process_img(self, image):
        """Turn a single PIL image into a normalized tensor.

        Parameters
        ----------
        image : Image.Image
            Image to process. Non-RGB images (grayscale, palette, RGBA, ...)
            are converted to RGB first.

        Returns
        -------
        torch.Tensor
            Normalized image tensor (C H W).
        """
        # The normalization statistics are per RGB channel; an image with a
        # different channel count would make F.normalize fail, so bring every
        # input to RGB up front. RGB inputs are left untouched.
        if image.mode != "RGB":
            image = image.convert("RGB")

        img = self.square_pad(image)
        # The image is square at this point, so resizing with a single int
        # yields an `input_resolution` x `input_resolution` image.
        img = F.resize(img, self.input_resolution)
        img = F.to_tensor(img)
        img = F.normalize(img, self.image_mean, self.image_std)
        return img

    def __call__(self, images, texts=None):
        """
        Parameters
        ----------
        images : Union[Image.Image, List[Image.Image]]
            Image or list of images to process
        texts : Union[str, List[str]]
            Text or list of texts to process. Pass through, no operation is
            performed.

        Returns
        -------
        BatchFeature
            pixel_values : torch.Tensor
                Processed image tensor. (C H W) when a single image is
                given, (B C H W) when an iterable of images is given.
            texts : Union[str, List[str]]
                Only present when ``texts`` is not None.
        """
        data = {}

        if isinstance(images, Image.Image):
            # Single image: returned without a batch dimension.
            data["pixel_values"] = self.process_img(images)
        else:
            # Iterable of images: stacked along a new leading batch dimension.
            data["pixel_values"] = torch.stack(
                [self.process_img(img) for img in images]
            )

        if texts is not None:
            data["texts"] = texts

        return BatchFeature(data=data)