"""Image processor and batch feature container for MiniCPM-V style models."""
import math
from typing import Optional, Union, Dict, Any

import numpy as np
import PIL
import PIL.Image
import PIL.ImageSequence
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
    ImageInput,
    make_list_of_images,
    valid_images,
    is_torch_tensor,
    to_numpy_array,
    infer_channel_dimension_format,
    ChannelDimension,
)
from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device


def recursive_converter(converter, value):
    """Apply ``converter`` to every leaf of a (possibly nested) list.

    Only ``list`` instances are recursed into; anything else (including tuples)
    is treated as a leaf and passed to ``converter``. The list nesting structure
    is preserved in the returned value.
    """
    if isinstance(value, list):
        return [recursive_converter(converter, v) for v in value]
    return converter(value)


class MiniCPMVBatchFeature(BatchFeature):
    r"""
    BatchFeature variant that supports images of different sizes.

    Values may be nested lists whose leaves have different shapes; tensor
    conversion and device/dtype casting are applied leaf-by-leaf (see
    ``recursive_converter``) instead of stacking everything into one tensor.
    """

    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
        super().__init__(data)
        self.convert_to_tensors(tensor_type=tensor_type)

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """Convert every leaf of every value to ``tensor_type`` in place.

        Args:
            tensor_type: Target tensor framework (e.g. ``"pt"``, ``"np"``). When
                ``None`` nothing is converted.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If a leaf cannot be converted to a tensor.
        """
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        def make_converter(key):
            # Bind ``key`` explicitly so the error message refers to the value
            # actually being converted (avoids a late-binding closure).
            def converter(value):
                if is_tensor(value):
                    # Already the requested type: keep it (previously this
                    # path returned None and silently dropped the value).
                    return value
                try:
                    return as_tensor(value)
                except Exception as err:
                    if key == "overflowing_values":
                        raise ValueError(
                            "Unable to create tensor returning overflowing values of different lengths. "
                        ) from err
                    raise ValueError(
                        "Unable to create tensor, you should probably activate padding "
                        "with 'padding=True' to have batched tensors with the same length."
                    ) from err

            return converter

        # Reassigning existing keys does not change the dict size, so updating
        # while iterating is safe here.
        for key, value in self.items():
            self[key] = recursive_converter(make_converter(key), value)
        return self

    def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
        """Move/cast the tensors in this batch, mirroring ``torch.Tensor.to``.

        Floating-point tensors receive the full ``*args, **kwargs``. Other tensors
        are only moved to ``device`` (when one was given). Non-tensor leaves are
        returned unchanged.

        Returns:
            ``self``, with ``self.data`` replaced by the converted mapping.

        Raises:
            ValueError: If the first positional argument is neither a dtype nor a
                device-like value (str, int or ``torch.device``).
        """
        requires_backends(self, ["torch"])

        device = kwargs.get("device")
        # Check if the args are a device or a dtype
        if device is None and len(args) > 0:
            # device should be always the first argument
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # it's something else
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def cast_tensor(v):
            if not torch.is_tensor(v):
                # e.g. numpy arrays or tuples when no tensor type was requested.
                return v
            if torch.is_floating_point(v):
                # cast and send to device
                return v.to(*args, **kwargs)
            if device is not None:
                return v.to(device=device)
            return v

        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
        self.data = {k: recursive_converter(cast_tensor, v) for k, v in self.items()}
        return self


class MiniCPMVImageProcessor(BaseImageProcessor):
    """Slice images into a downsampled overview plus a grid of patches.

    Each image is resized so that its sides are multiples of ``patch_size``.
    Larger images are additionally split into up to ``max_slice_nums`` slices.
    Every resulting image is normalized and flattened into
    ``[C, patch_size, H * W / patch_size]`` by ``reshape_by_patch``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
            self,
            max_slice_nums=9,
            scale_resolution=448,
            patch_size=14,
            **kwargs):
        # NOTE(review): all kwargs (including the keys popped below) are
        # forwarded to the base class first; kept as-is to preserve behavior.
        super().__init__(**kwargs)
        self.max_slice_nums = max_slice_nums
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size

        self.image_feature_size = kwargs.pop("image_feature_size", 64)
        self.im_start_token = kwargs.pop("im_start", "")
        self.im_end_token = kwargs.pop("im_end", "")
        self.slice_start_token = kwargs.pop("slice_start", "")
        self.slice_end_token = kwargs.pop("slice_end", "")
        self.unk_token = kwargs.pop("unk", "")
        self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
        self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
        self.version = kwargs.pop("version", 2.0)

    def ensure_divide(self, length, patch_size):
        """Round ``length`` to the nearest multiple of ``patch_size`` (at least ``patch_size``)."""
        return max(round(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
        """Compute a ``(width, height)`` with both sides divisible by ``patch_size``.

        If the area exceeds ``scale_resolution ** 2`` (or ``allow_upscale`` is set),
        the size is first rescaled to an area of about ``scale_resolution ** 2``
        while keeping the aspect ratio.
        """
        width, height = original_size
        if (width * height > scale_resolution * scale_resolution) or allow_upscale:
            aspect = width / height
            height = int(scale_resolution / math.sqrt(aspect))
            width = int(height * aspect)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        return (best_width, best_height)

    def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
        """Return the full-image size such that each of the ``grid`` cells has a
        ``find_best_resize``-compatible size."""
        width, height = original_size
        grid_x, grid_y = grid

        refine_width = self.ensure_divide(width, grid_x)
        refine_height = self.ensure_divide(height, grid_y)

        grid_width = refine_width / grid_x
        grid_height = refine_height / grid_y

        best_grid_size = self.find_best_resize(
            (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
        )
        return (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)

    def split_to_patches(self, image, grid):
        """Crop ``image`` into a ``grid[1]`` x ``grid[0]`` list of rows of PIL patches."""
        width, height = image.size
        step_x = int(width / grid[0])
        step_y = int(height / grid[1])
        return [
            [image.crop((left, top, left + step_x, top + step_y)) for left in range(0, width, step_x)]
            for top in range(0, height, step_y)
        ]

    def _find_best_grid(self, log_ratio, multiple, max_slice_nums):
        """Pick the ``[cols, rows]`` grid whose aspect ratio best matches ``log_ratio``.

        Candidate slice counts are ``multiple - 1``, ``multiple`` and ``multiple + 1``,
        excluding 1 and anything above ``max_slice_nums``. Every factorization of
        each count is considered. Ties keep the first candidate found.
        """
        candidate_nums = [
            n for n in (multiple - 1, multiple, multiple + 1) if n != 1 and n <= max_slice_nums
        ]
        candidate_grids = [
            [m, n // m] for n in candidate_nums for m in range(1, n + 1) if n % m == 0
        ]

        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error
        return best_grid

    def slice_image(
        self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
    ):
        """Split ``image`` into a resized overview image plus a grid of patches.

        Returns:
            ``(source_image, patches, best_grid)``. ``patches`` is a list of rows
            of PIL images (empty when not split). ``best_grid`` is ``[cols, rows]``,
            or ``None`` when the image is not split.
        """
        original_size = image.size
        original_width, original_height = original_size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (scale_resolution * scale_resolution)
        multiple = min(math.ceil(ratio), max_slice_nums)

        if multiple <= 1 or never_split:
            # dont need to slice, upsample
            best_size = self.find_best_resize(
                original_size, scale_resolution, patch_size, allow_upscale=True
            )
            source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
            return source_image, [], None

        # source image, down-sampling and ensure divided by patch_size
        best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
        source_image = image.resize(best_resize, resample=Image.Resampling.BICUBIC)

        best_grid = self._find_best_grid(log_ratio, multiple, max_slice_nums)
        refine_size = self.get_refine_size(
            original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
        )
        refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
        patches = self.split_to_patches(refine_image, best_grid)
        return source_image, patches, best_grid

    def get_grid_placeholder(self, grid):
        """Build the placeholder text for a ``[cols, rows]`` slice grid ("" if ``grid`` is None)."""
        if grid is None:
            return ""
        image_placeholder = (
            self.im_start_token
            + self.unk_token * self.image_feature_size
            + self.im_end_token
        )

        cols = grid[0]
        rows = grid[1]
        # One line per grid row, each holding ``cols`` image placeholders.
        row_text = image_placeholder * cols
        return self.slice_start_token + "\n".join([row_text] * rows) + self.slice_end_token

    def get_sliced_images(self, image):
        """Return the overview image followed by all patches in row-major order."""
        source_image, patches, _ = self.slice_image(
            image,
            self.max_slice_nums,  # default: 9
            self.scale_resolution,  # default: 448
            self.patch_size  # default: 14
        )
        return [source_image] + [patch for row in patches for patch in row]

    def get_sliced_grid(self, image_size):
        """Return the ``[cols, rows]`` grid ``slice_image`` would use, or None if unsplit."""
        original_width, original_height = image_size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
        multiple = min(math.ceil(ratio), self.max_slice_nums)
        if multiple <= 1:
            return None
        return self._find_best_grid(log_ratio, multiple, self.max_slice_nums)

    def get_slice_image_placeholder(self, image_size):
        """Return the full placeholder text (overview + slice grid) for an image size."""
        grid = self.get_sliced_grid(image_size=image_size)
        return (
            self.im_start_token
            + self.unk_token * self.image_feature_size
            + self.im_end_token
        ) + self.get_grid_placeholder(grid=grid)

    def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        if isinstance(image, PIL.Image.Image):
            return image
        if is_torch_tensor(image):
            # detach/cpu so tensors that require grad or live on an accelerator also convert.
            image = image.detach().cpu().numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale default to the array being of floating type.
                rescale = np.issubdtype(image.dtype, np.floating)
            # If the channel has been moved to first dim, we put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def reshape_by_patch(self, image):
        """
        Rearrange a channels-first image into patch-major layout.

        :param image: numpy array of shape [C, H, W]; H and W are assumed to be
            multiples of ``self.patch_size``.
        :return: numpy array of shape [C, patch_size, H * W / patch_size]
        """
        image = torch.from_numpy(image)
        patch_size = self.patch_size
        patches = torch.nn.functional.unfold(
            image,
            (patch_size, patch_size),
            stride=(patch_size, patch_size)
        )

        patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
        patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
        return patches.numpy()

    def preprocess(self, images, do_pad=True, return_tensors=None):
        """Run ``preprocess2`` on each element of ``images`` independently.

        Args:
            images: An iterable of per-sample image inputs (each accepted by
                ``preprocess2``).
            do_pad: Forwarded to ``preprocess2`` (currently unused there).
            return_tensors: Tensor type forwarded to ``preprocess2``.

        Returns:
            A list with one ``MiniCPMVBatchFeature`` per element of ``images``.
        """
        return [self.preprocess2(image, do_pad, return_tensors) for image in images]

    def preprocess2(
            self,
            images: ImageInput,
            do_pad: Optional[bool] = True,  # TODO: add pad for MiniCPM-Llama3-V-2_5
            return_tensors: Optional[Union[str, TensorType]] = None
        ) -> MiniCPMVBatchFeature:
        """Slice, normalize and patch-reshape one or more images.

        Returns:
            A ``MiniCPMVBatchFeature`` with keys ``pixel_values`` (one reshaped
            array per slice), ``image_sizes`` (original PIL ``(width, height)``)
            and ``tgt_sizes`` (per-slice patch counts along H and W). Each value
            is wrapped in an outer single-element list.

        Raises:
            ValueError: If ``images`` contains an unsupported image type.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        images = [self.to_pil_image(image).convert("RGB") for image in images]
        # After .convert("RGB") every image has the same channel layout.
        input_data_format = infer_channel_dimension_format(np.array(images[0]))

        new_images = []
        image_sizes = [image.size for image in images]
        tgt_sizes = []
        for image in images:
            for slice_pil in self.get_sliced_images(image):
                array = to_numpy_array(slice_pil).astype(np.float32) / 255
                array = self.normalize(
                    image=array, mean=self.mean, std=self.std, input_data_format=input_data_format
                )
                array = to_channel_dimension_format(
                    array, ChannelDimension.FIRST, input_channel_dim=input_data_format
                )
                new_images.append(self.reshape_by_patch(array))
                tgt_sizes.append(np.array((array.shape[1] // self.patch_size, array.shape[2] // self.patch_size)))

        if tgt_sizes:
            tgt_sizes = np.vstack(tgt_sizes)

        return MiniCPMVBatchFeature(
            data={"pixel_values": [new_images], "image_sizes": [image_sizes], "tgt_sizes": [tgt_sizes]},
            tensor_type=return_tensors
        )


# NOTE(review): registers with a string key rather than a config class; kept as in the original.
AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)