from typing import Dict, List, Optional, Union

import numpy as np
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from transformers.utils import TensorType
from transformers.models.clip.image_processing_clip import logger, CLIPImageProcessor
from mmdet.models.utils import multi_apply


class CustomLlavaImageProcessor(CLIPImageProcessor):
    """CLIP image processor variant that resizes by the longest edge and pads to a square.

    Compared to the parent processor, `resize` scales the *longest* edge to
    `size["shortest_edge"]`, and `preprocess` replaces center cropping with
    mean-colour padding (`pad`), returning the padding metadata alongside the
    pixel values.
    """

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image.

        When `size` contains `"shortest_edge"`, the *longest* edge of the image is resized to
        `size["shortest_edge"]` and the other edge is scaled to keep the input aspect ratio
        (this is the customization over the parent class). When `size` contains `"height"` and
        `"width"`, the image is resized to exactly that size.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image. Must contain either `"shortest_edge"` or both
                `"height"` and `"width"`.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Returns:
            `np.ndarray`: The resized image.

        Raises:
            ValueError: If `size` contains neither `"shortest_edge"` nor `"height"` and `"width"`.
        """
        default_to_square = True
        if "shortest_edge" in size:
            target = size["shortest_edge"]
            default_to_square = False
            # customization: force the largest edge to size
            h, w = get_image_size(image, channel_dim=input_data_format)
            # Clamp the scaled edge to at least one pixel so that extreme aspect
            # ratios do not request a zero-sized output.
            if h > w:
                size = (target, max(1, int(w * target / h)))
            else:
                size = (max(1, int(h * target / w)), target)
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError(
                "Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Union[int, Dict[str, int]]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Preprocess an image or a batch of images.

        The pipeline is: optional RGB conversion, conversion to numpy, optional resize
        (see `resize`), padding to a square (see `pad`), optional rescale, optional
        normalization, and conversion to `data_format`. Center cropping is intentionally
        not applied; `do_center_crop` and `crop_size` are only resolved and validated.

        Every argument left as `None` falls back to the attribute of the same name on
        the processor instance.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess.
            do_resize (`bool`, *optional*):
                Whether to resize the images.
            size (`Dict[str, int]`, *optional*):
                Target size passed to `resize`.
            resample (`PILImageResampling`, *optional*):
                Resampling filter used when resizing.
            do_center_crop (`bool`, *optional*):
                Validated but otherwise unused: no center crop is performed.
            crop_size (`int` or `Dict[str, int]`, *optional*):
                Validated but otherwise unused: no center crop is performed.
            do_rescale (`bool`, *optional*):
                Whether to multiply pixel values by `rescale_factor`.
            rescale_factor (`float`, *optional*):
                Factor used when `do_rescale` is set.
            do_normalize (`bool`, *optional*):
                Whether to normalize with `image_mean` and `image_std`.
            image_mean (`float` or `List[float]`, *optional*):
                Mean used for normalization and for the padding colour.
            image_std (`float` or `List[float]`, *optional*):
                Standard deviation used for normalization.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the images to RGB first.
            return_tensors (`str` or `TensorType`, *optional*):
                Tensor type forwarded to `BatchFeature`.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                Channel dimension format of the output images.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Channel dimension format of the input images. If not provided, it is
                inferred from the first image.

        Returns:
            `BatchFeature` with the keys `"pixel_values"` (processed images),
            `"image_sizes"` (the `(height, width)` of each image before resizing) and
            `"meta_datas"` (the per-image padding metadata returned by `pad`).

        Raises:
            ValueError: If `images` contains an unsupported image type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(
            crop_size, param_name="crop_size", default_to_square=True)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        validate_kwargs(captured_kwargs=kwargs.keys(),
                        valid_processor_keys=self._valid_processor_keys)

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # Recorded before resizing, so these are the original image sizes.
        image_sizes = [
            get_image_size(image, channel_dim=input_data_format) for image in images
        ]

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample,
                            input_data_format=input_data_format)
                for image in images
            ]

        # We do not apply center crop; images are padded to a square instead.
        # The layout and the effective mean are forwarded so the padding agrees
        # with the rest of the pipeline (channel axis and normalization mean).
        images, meta_datas = multi_apply(
            self.pad,
            images,
            input_data_format=input_data_format,
            image_mean=image_mean,
        )

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor,
                             input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std,
                               input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(
                image, data_format, input_channel_dim=input_data_format)
            for image in images
        ]

        data = {"pixel_values": images,
                "image_sizes": image_sizes, "meta_datas": meta_datas}
        return BatchFeature(data=data, tensor_type=return_tensors)

    def pad(self, image, input_data_format=None, image_mean=None):
        """
        Pad an image to a square, centering it on a mean-coloured canvas.

        The canvas side is `max(height, width)`. The padding colour is
        `int(mean * 255)` per channel, cast to the dtype of `image`.
        NOTE(review): the `* 255` factor presumes pixel values in the 0-255 range;
        confirm behaviour for inputs that are already scaled to [0, 1].

        Args:
            image (`np.ndarray`):
                3-D image array.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                Channel dimension format of `image`. If not provided, channels-last
                `(height, width, channels)` is assumed.
            image_mean (`float` or `List[float]`, *optional*):
                Mean used to derive the padding colour. Defaults to `self.image_mean`.
                A scalar is applied to every channel.

        Returns:
            `Tuple[np.ndarray, dict]`: The padded image (same layout and dtype as the
            input) and a metadata dict with the keys `"padding"` (`before_height`,
            `after_height`, `before_width`, `after_width`), `"image_shape"` and
            `"padded_shape"` (each with `height` and `width`).

        Raises:
            TypeError: If `image` is not a `np.ndarray`.
            ValueError: If `image` is not 3-D, or the mean does not match the number
                of channels.
        """
        # Validate before touching `image.dtype`; an explicit raise also survives `-O`.
        if not isinstance(image, np.ndarray):
            raise TypeError(
                f"`image` must be a numpy.ndarray, got {type(image).__name__}.")
        if image.ndim != 3:
            raise ValueError(
                f"`image` must be a 3-D array, got shape {image.shape}.")

        channels_first = (
            input_data_format is not None
            and ChannelDimension(input_data_format) == ChannelDimension.FIRST
        )
        if channels_first:
            num_channels, h, w = image.shape
        else:
            h, w, num_channels = image.shape

        mean = self.image_mean if image_mean is None else image_mean
        if isinstance(mean, (int, float)):
            mean = [mean] * num_channels
        pad_value = np.array([int(x * 255) for x in mean], dtype=image.dtype)
        if pad_value.size not in (1, num_channels):
            raise ValueError(
                f"Got {pad_value.size} mean values for an image with {num_channels} channels.")

        size = max(h, w)
        pad_height, pad_width = size - h, size - w
        # Any odd leftover pixel goes to the bottom/right side.
        before_height, before_width = pad_height // 2, pad_width // 2
        after_height = pad_height - before_height
        after_width = pad_width - before_width

        rows = slice(before_height, before_height + h)
        cols = slice(before_width, before_width + w)
        if channels_first:
            new_image = np.empty((num_channels, size, size), dtype=image.dtype)
            new_image[...] = pad_value.reshape(-1, 1, 1)
            new_image[:, rows, cols] = image
        else:
            new_image = np.empty((size, size, num_channels), dtype=image.dtype)
            new_image[...] = pad_value
            new_image[rows, cols] = image

        meta = dict(
            padding=dict(
                before_height=before_height,
                after_height=after_height,
                before_width=before_width,
                after_width=after_width,
            ),
            image_shape=dict(height=h, width=w),
            padded_shape=dict(height=size, width=size),
        )
        return new_image, meta