| """Image processor and decoding helpers for Yasa2.""" |
|
|
| from __future__ import annotations |
|
|
| import io |
| import math |
| from typing import List, Tuple |
|
|
| import numpy as np |
| from PIL import Image |
| from transformers import ConvNextImageProcessor |
|
|
|
|
class Yasa2ImageProcessor(ConvNextImageProcessor):
    """ConvNeXt-based image processor that also carries Yasa2 tiling settings.

    Besides whatever the parent class configures, every instance exposes four
    attributes read from the constructor keyword arguments: ``use_navit``,
    ``max_tiles_num``, ``patch_size`` and ``tiling_method``.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, *args, **kwargs):
        """Apply Yasa2 defaults, initialise the parent, then store tiling options.

        Args:
            *args: Positional arguments passed through to
                ``ConvNextImageProcessor``.
            **kwargs: Keyword arguments passed through to
                ``ConvNextImageProcessor``. ``size``, ``do_resize``,
                ``do_center_crop`` and ``do_normalize`` receive Yasa2 defaults
                when absent; ``use_navit``, ``max_tiles_num``, ``patch_size``
                and ``tiling_method`` are additionally stored on the instance.
        """
        # Preprocessing defaults: only filled in when the caller left them out.
        preprocessing_defaults = (
            ("size", {"shortest_edge": 512}),
            ("do_resize", True),
            ("do_center_crop", False),
            ("do_normalize", True),
        )
        for option, fallback in preprocessing_defaults:
            kwargs.setdefault(option, fallback)

        super().__init__(*args, **kwargs)

        # Tiling metadata is (re)assigned after the parent constructor so the
        # attributes always exist, even when the caller did not pass them.
        tiling_defaults = (
            ("use_navit", False),
            ("max_tiles_num", 4),
            ("patch_size", 14),
            ("tiling_method", "llava-uhd"),
        )
        for attribute, fallback in tiling_defaults:
            setattr(self, attribute, kwargs.get(attribute, fallback))
|
|
|
|
def image_rgb_decoder_pil(
    image_bytes: bytes, skip_errors: bool = False
) -> dict:
    """Decode encoded image bytes into an RGB numpy array.

    The bytes are opened with PIL, converted to RGB mode and turned into a
    numpy array whose rank and channel count are then sanity-checked.

    Args:
        image_bytes: Encoded image file contents.
        skip_errors: When true, failures are reported in the returned dict
            instead of being raised.

    Returns:
        ``{"pixel_values": array}`` on success, or ``{"error": message}`` when
        decoding failed and ``skip_errors`` is true.

    Raises:
        Exception: Whatever decoding or validation raised, re-raised unchanged
            when ``skip_errors`` is false.
    """
    try:
        buffer = io.BytesIO(image_bytes)
        rgb_image = Image.open(buffer).convert("RGB")
        pixel_values = np.array(rgb_image)

        # Validate the decoded array before handing it to callers.
        if pixel_values.ndim == 4:
            raise ValueError(
                "Image has 4 dimensions, expected 3 (possible GIF with jpg/png extension)."
            )
        if pixel_values.shape[2] != 3:
            raise ValueError(
                f"Image has {pixel_values.shape[2]} channels, expected 3."
            )
    except Exception as exc:
        # Best-effort mode: surface the failure as data rather than raising.
        if skip_errors:
            return {"error": str(exc)}
        raise

    return {"pixel_values": pixel_values}
|
|
|
|
def image_rgb_decoder_pil_tiling(
    image_bytes: bytes,
    skip_errors: bool = False,
    size: int = 1024,
    grid_pinpoints: List[Tuple[int, int]] = None,
    max_tiles_num: int = 9,
    patch_size: int = 4,
    tiling_method: str = "llava-next",
) -> dict:
    """Decode encoded image bytes and split the image into tiles.

    Two tiling schemes are supported (matched case-insensitively):
    ``"llava-next"`` stacks all tiles into one 4D numpy array, while
    ``"llava-uhd"`` returns a list of per-tile arrays.

    Args:
        image_bytes: Encoded image file contents.
        skip_errors: When true, failures are reported in the returned dict
            instead of being raised.
        size: Tile edge for ``llava-next``; passed as ``scale_resolution``
            for ``llava-uhd``.
        grid_pinpoints: Candidate grids for ``llava-next``. A built-in list
            of seven grids is used when omitted.
        max_tiles_num: Tile limit forwarded to the ``llava-uhd`` tiler.
        patch_size: Patch size forwarded to the ``llava-uhd`` tiler.
        tiling_method: Name of the tiling scheme.

    Returns:
        A dict with ``"pixel_values"``, ``"num_tiles"`` and ``"img_tiling"``
        on success, or ``{"error": message}`` when something failed and
        ``skip_errors`` is true.

    Raises:
        ValueError: For an unknown tiling method or unexpectedly shaped tiles
            (only when ``skip_errors`` is false).
        Exception: Any other decoding failure, re-raised unchanged when
            ``skip_errors`` is false.
    """
    if grid_pinpoints is None:
        grid_pinpoints = [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1)]

    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        method = tiling_method.lower()

        if method == "llava-next":
            tiles = process_anyres_image(image, size, grid_pinpoints)
            pixel_values = np.array([np.array(tile) for tile in tiles])
            if pixel_values.ndim != 4:
                raise ValueError(
                    "Tiled image has unexpected dimensions (expected 4D)."
                )
            if pixel_values.shape[3] != 3:
                raise ValueError(
                    f"Tiled image has {pixel_values.shape[3]} channels, expected 3."
                )
        elif method == "llava-uhd":
            tiles = process_anyres_image_uhd(
                image,
                max_tiles_num=max_tiles_num,
                scale_resolution=size,
                patch_size=patch_size,
                never_split=False,
            )
            # Tiles may differ in size, so they stay a list of arrays; only
            # the last entry is validated.
            pixel_values = [np.array(tile) for tile in tiles]
            if pixel_values[-1].ndim != 3:
                raise ValueError(
                    "UHD tiled image has unexpected dimensions (expected 3D)."
                )
            if pixel_values[-1].shape[2] != 3:
                raise ValueError(
                    f"UHD tiled image has {pixel_values[-1].shape[2]} channels, expected 3."
                )
        else:
            raise ValueError(f"Unknown tiling method: {tiling_method}")

        return {
            "pixel_values": pixel_values,
            "num_tiles": len(pixel_values),
            "img_tiling": True,
        }
    except Exception as exc:
        # Best-effort mode: surface the failure as data rather than raising.
        if skip_errors:
            return {"error": str(exc)}
        raise
|
|
|
|
def resize_and_pad_image(
    image: Image.Image, target_resolution: Tuple[int, int]
) -> Image.Image:
    """Fit an image inside a target canvas, keeping aspect ratio and centring it.

    The image is scaled by the smaller of the two axis ratios so that it fits
    entirely inside the target, then pasted in the middle of a black RGB
    canvas of exactly the target size.

    Args:
        image: Input PIL image.
        target_resolution: Canvas size as (width, height).

    Returns:
        A new RGB image of size ``target_resolution``.
    """
    source_w, source_h = image.size
    canvas_w, canvas_h = target_resolution
    ratio_w = canvas_w / source_w
    ratio_h = canvas_h / source_h

    # The tighter axis fills the canvas; the other is scaled by the same
    # factor and clamped so rounding up can never overflow the canvas.
    if ratio_w < ratio_h:
        fitted_w = canvas_w
        fitted_h = min(math.ceil(source_h * ratio_w), canvas_h)
    else:
        fitted_h = canvas_h
        fitted_w = min(math.ceil(source_w * ratio_h), canvas_w)

    fitted = image.resize((fitted_w, fitted_h))
    canvas = Image.new("RGB", (canvas_w, canvas_h), (0, 0, 0))
    offset = ((canvas_w - fitted_w) // 2, (canvas_h - fitted_h) // 2)
    canvas.paste(fitted, offset)
    return canvas
|
|
|
|
def select_best_resolution(
    original_size: Tuple[int, int],
    possible_resolutions: List[Tuple[int, int]],
) -> Tuple[int, int]:
    """Pick the candidate that keeps the most image pixels with the least waste.

    For every candidate the image is notionally scaled to fit inside it. The
    candidate with the largest effective pixel count (capped at the original
    pixel count) wins; ties go to the candidate with the fewest unused canvas
    pixels, and remaining ties to the earliest candidate.

    Args:
        original_size: Original image size (width, height).
        possible_resolutions: Candidate resolutions as (width, height).

    Returns:
        The winning (width, height), or ``None`` when no candidates are given.
    """
    original_width, original_height = original_size
    original_area = original_width * original_height

    chosen = None
    # Score is (effective pixels, -wasted pixels), compared lexicographically.
    best_score = (0, float("-inf"))

    for candidate in possible_resolutions:
        width, height = candidate
        scale = min(width / original_width, height / original_height)
        fitted_area = int(original_width * scale) * int(original_height * scale)
        effective = min(fitted_area, original_area)
        score = (effective, effective - width * height)
        if score > best_score:
            best_score = score
            chosen = (width, height)

    return chosen
|
|
|
|
def divide_to_patches(
    image: Image.Image, patch_size: int
) -> List[Image.Image]:
    """Cut an image into ``patch_size`` x ``patch_size`` crops, row by row.

    Crop boxes start at every multiple of ``patch_size`` along both axes, so
    boxes on the right/bottom edge extend past the image when its size is not
    an exact multiple of ``patch_size``.

    Args:
        image: Input PIL image.
        patch_size: Edge length of each square crop, in pixels.

    Returns:
        Crops in row-major order (top-to-bottom, left-to-right).
    """
    width, height = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
|
|
|
|
def process_anyres_image(
    image: Image.Image,
    size: int = 512,
    grid_pinpoints: List[Tuple[int, int]] = None,
) -> List[Image.Image]:
    """Produce LLaVA-Next style tiles: a global view followed by grid tiles.

    Each grid pinpoint is scaled by ``size`` to form a candidate canvas. The
    best-fitting canvas is chosen, the image is resized and padded onto it,
    and the canvas is cut into ``size`` x ``size`` tiles.

    Args:
        image: Input PIL image.
        size: Tile edge length in pixels.
        grid_pinpoints: Candidate grids; a built-in list of five grids is
            used when omitted.

    Returns:
        The whole image resized to ``(size, size)`` followed by the tiles.
    """
    if grid_pinpoints is None:
        grid_pinpoints = [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1)]

    candidate_canvases = [
        (grid_a * size, grid_b * size) for grid_a, grid_b in grid_pinpoints
    ]
    canvas_size = select_best_resolution(image.size, candidate_canvases)
    padded = resize_and_pad_image(image, canvas_size)
    tiles = divide_to_patches(padded, size)

    # The global view ignores aspect ratio: it is a direct resize to a square.
    overview = image.resize((size, size))
    return [overview, *tiles]
|
|
|
|
def estimate_num_tiles_llava_next(
    image_size: Tuple[int, int],
    size: int = 512,
    grid_pinpoints: List[Tuple[int, int]] = None,
) -> int:
    """Predict how many tiles LLaVA-Next tiling yields for a given image size.

    Only the image dimensions are needed; no pixel data is touched.

    Args:
        image_size: Image size as (width, height).
        size: Tile edge length in pixels.
        grid_pinpoints: Candidate grids; a built-in list of five grids is
            used when omitted.

    Returns:
        Number of grid tiles in the selected canvas plus one for the global
        view.
    """
    if grid_pinpoints is None:
        grid_pinpoints = [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1)]

    candidate_canvases = [
        (grid_a * size, grid_b * size) for grid_a, grid_b in grid_pinpoints
    ]
    canvas_w, canvas_h = select_best_resolution(image_size, candidate_canvases)

    tiles_across = int(canvas_w / size)
    tiles_down = int(canvas_h / size)
    return 1 + (tiles_across * tiles_down)
|
|
|
|
def split_to_patches(
    image: Image.Image, grid: Tuple[int, int]
) -> List[Image.Image]:
    """Cut an image into exactly ``grid[0] * grid[1]`` equally sized tiles.

    Each tile is ``width // grid[0]`` pixels wide and ``height // grid[1]``
    pixels tall. Tiles are produced by iterating over grid indices rather
    than stepping across the pixel range, so an image whose size is not an
    exact multiple of the grid never yields an extra partial row or column;
    the leftover pixels on the right/bottom edge are simply not covered.

    Args:
        image: Input PIL image.
        grid: Grid dimensions as (columns, rows).

    Returns:
        Tiles in row-major order (top-to-bottom, left-to-right).

    Raises:
        ValueError: If the image is smaller than the grid along an axis, so a
            tile would have no width or height.
    """
    width, height = image.size
    columns, rows = grid[0], grid[1]
    tile_width = width // columns
    tile_height = height // rows
    if tile_width <= 0 or tile_height <= 0:
        raise ValueError(
            f"Cannot split a {width}x{height} image into a {columns}x{rows} grid."
        )

    patches = []
    for row in range(rows):
        top = row * tile_height
        for column in range(columns):
            left = column * tile_width
            box = (left, top, left + tile_width, top + tile_height)
            patches.append(image.crop(box))
    return patches
|
|
|
|
def ensure_divide(length: float, patch_size: int) -> int:
    """Snap a length to the nearest multiple of ``patch_size``, at least one patch.

    The quotient is rounded with Python's built-in ``round`` (ties go to the
    even integer), so the result may be smaller or larger than ``length``.

    Args:
        length: Raw length to align.
        patch_size: Patch size to align to.

    Returns:
        The nearest multiple of ``patch_size``, never less than ``patch_size``.
    """
    whole_patches = round(length / patch_size)
    aligned = whole_patches * patch_size
    # Guard against collapsing to zero for lengths well below one patch.
    return max(aligned, patch_size)
|
|
|
|
def find_best_resize(
    original_size: Tuple[int, int],
    scale_resolution: int,
    patch_size: int,
    allow_upscale: bool = False,
) -> Tuple[int, int]:
    """Compute a patch-aligned size with an area near ``scale_resolution`` squared.

    When the image covers more pixels than ``scale_resolution ** 2`` (or when
    ``allow_upscale`` is set), the size is rescaled to roughly that area while
    keeping the aspect ratio. Otherwise the original size is kept. Either way
    both sides are then aligned with ``ensure_divide``.

    Args:
        original_size: Original image size (width, height).
        scale_resolution: Target edge length whose square is the pixel budget.
        patch_size: Patch size both sides are aligned to.
        allow_upscale: Rescale even when the image is within the budget.

    Returns:
        Aligned (width, height).
    """
    width, height = original_size
    over_budget = width * height > scale_resolution * scale_resolution

    if over_budget or allow_upscale:
        aspect_ratio = width / height
        # Height is derived first; width then follows from the already
        # truncated height to keep the ratio.
        height = int(scale_resolution / math.sqrt(aspect_ratio))
        width = int(height * aspect_ratio)

    return (
        ensure_divide(width, patch_size),
        ensure_divide(height, patch_size),
    )
|
|
|
|
def get_refine_size(
    original_size: Tuple[int, int],
    grid: Tuple[int, int],
    scale_resolution: int,
    patch_size: int,
    allow_upscale: bool = False,
) -> Tuple[int, int]:
    """Compute the full-image size whose grid cells are each patch-aligned.

    The image size is first aligned to the grid so it divides evenly into
    cells. The best size for a single cell is then found with
    ``find_best_resize`` and multiplied back up by the grid dimensions.

    Args:
        original_size: Original image size (width, height).
        grid: Tile grid as (columns, rows).
        scale_resolution: Target resolution for a single cell.
        patch_size: Patch size each cell is aligned to.
        allow_upscale: Whether a cell may be rescaled even when it is within
            the pixel budget.

    Returns:
        Refined (width, height) for the whole image.
    """
    columns, rows = grid
    image_width, image_height = original_size

    # Align the full size to the grid, then work out one cell's dimensions.
    cell_width = ensure_divide(image_width, columns) / columns
    cell_height = ensure_divide(image_height, rows) / rows

    best_cell_width, best_cell_height = find_best_resize(
        (cell_width, cell_height),
        scale_resolution,
        patch_size,
        allow_upscale=allow_upscale,
    )
    return (best_cell_width * columns, best_cell_height * rows)
|
|
|
|
def process_anyres_image_uhd(
    image: Image.Image,
    max_tiles_num: int = 9,
    scale_resolution: int = 448,
    patch_size: int = 4,
    never_split: bool = False,
) -> List[Image.Image]:
    """Process an image into tiles for LLaVA-UHD style tiling.

    The number of tiles is derived from how many ``scale_resolution`` squares
    the image area covers, capped at ``max_tiles_num``. Among the grids whose
    tile count is that number or one either side of it, the grid whose aspect
    ratio is closest (in log space) to the image's is chosen.

    Args:
        image: Input PIL image.
        max_tiles_num: Maximum number of grid tiles to generate.
        scale_resolution: Target resolution for one tile and for the overview.
        patch_size: Patch size that resized dimensions are aligned to.
        never_split: Return only the resized image, never any grid tiles.

    Returns:
        A single resized image when no split is needed; otherwise the grid
        tiles followed by the resized full image as the last element.
    """
    original_width, original_height = image.size
    log_ratio = math.log(original_width / original_height)
    ratio = (original_width * original_height) / (
        scale_resolution * scale_resolution
    )
    multiple = min(math.ceil(ratio), max_tiles_num)

    if multiple <= 1 or never_split:
        # Image fits within one tile: rescale it (upscaling allowed).
        best_size = find_best_resize(
            image.size, scale_resolution, patch_size, allow_upscale=True
        )
        return [image.resize(best_size, Image.Resampling.BICUBIC)]

    # Tile counts to consider: the estimate and its neighbours, excluding a
    # single tile and anything above the cap.
    candidate_split_grids_nums = [
        count
        for count in (multiple - 1, multiple, multiple + 1)
        if count != 1 and count <= max_tiles_num
    ]

    # ``resize`` already returns a new image, so no defensive copy is needed.
    best_resize = find_best_resize(image.size, scale_resolution, patch_size)
    source_image = image.resize(best_resize, Image.Resampling.BICUBIC)

    # Every (columns, rows) factorisation of each candidate tile count.
    candidate_grids = [
        [columns, count // columns]
        for count in candidate_split_grids_nums
        for columns in range(1, count + 1)
        if count % columns == 0
    ]

    # ``min`` keeps the first grid on ties, matching a strict "<" scan.
    best_grid = min(
        candidate_grids,
        key=lambda grid: abs(log_ratio - math.log(grid[0] / grid[1])),
        default=[1, 1],
    )

    refine_size = get_refine_size(
        image.size,
        (best_grid[0], best_grid[1]),
        scale_resolution,
        patch_size,
        allow_upscale=True,
    )
    refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
    patches = split_to_patches(refine_image, (best_grid[0], best_grid[1]))
    return patches + [source_image]
|
|
|
|
def estimate_num_tiles_llava_uhd(
    image_size: Tuple[int, int],
    max_tiles_num: int = 9,
    scale_resolution: int = 448,
    patch_size: int = 4,
    never_split: bool = False,
) -> int:
    """Predict how many tiles LLaVA-UHD tiling yields for a given image size.

    Only the image dimensions are needed; no pixel data is touched.

    Args:
        image_size: Image size as (width, height).
        max_tiles_num: Maximum number of grid tiles.
        scale_resolution: Target resolution of a single tile.
        patch_size: Accepted for signature parity; not used in the estimate.
        never_split: When true the estimate is always one.

    Returns:
        ``1`` when the image is not split, otherwise the number of grid tiles
        plus one for the resized full image.
    """
    width, height = image_size
    log_ratio = math.log(width / height)
    area_ratio = (width * height) / (scale_resolution * scale_resolution)
    multiple = min(math.ceil(area_ratio), max_tiles_num)

    if never_split or multiple <= 1:
        return 1

    # Tile counts to consider: the estimate and its neighbours, excluding a
    # single tile and anything above the cap.
    tile_counts = [
        count
        for count in (multiple - 1, multiple, multiple + 1)
        if count != 1 and count <= max_tiles_num
    ]

    # Every (columns, rows) factorisation of each candidate tile count.
    grids = [
        [columns, count // columns]
        for count in tile_counts
        for columns in range(1, count + 1)
        if count % columns == 0
    ]

    # Closest aspect ratio in log space; ``min`` keeps the first on ties.
    columns, rows = min(
        grids,
        key=lambda grid: abs(log_ratio - math.log(grid[0] / grid[1])),
        default=[1, 1],
    )
    return (columns * rows) + 1
|
|
|
|
# Import-time side effect: call the register_for_auto_class() hook that
# Yasa2ImageProcessor inherits, using its default arguments.
# NOTE(review): presumably this lets the transformers Auto* loaders resolve
# this custom processor class — confirm against the transformers docs.
Yasa2ImageProcessor.register_for_auto_class()
|
|