|
|
from __future__ import annotations |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
def pil_to_np(img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a float32 array of RGB values scaled to [0, 1].

    Images whose mode is neither "RGB" nor "RGBA" (greyscale "L" included)
    are converted to "RGB" first. A fourth (alpha) channel is dropped, so
    the result always has three channels on its last axis.
    """
    # One guard covers both "exotic mode" and "greyscale": everything that is
    # not already RGB/RGBA goes through a single convert("RGB").
    if img.mode != "RGB" and img.mode != "RGBA":
        img = img.convert("RGB")

    pixels = np.asarray(img).astype(np.float32)

    # Defensive: a 2-D (single-channel) array is replicated to 3 channels.
    if pixels.ndim == 2:
        pixels = np.repeat(pixels[:, :, np.newaxis], 3, axis=2)

    # Keep only R, G, B when an alpha channel is present.
    if pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]

    return pixels / 255.0
|
|
|
|
|
def np_to_pil(arr: np.ndarray) -> Image.Image:
    """Convert a float array with values in the 0-1 range into a PIL image.

    Values are scaled by 255, clipped to [0, 255], rounded to the nearest
    integer and cast to uint8 before being handed to ``Image.fromarray``.
    Pillow chooses the image mode from the array it receives (an
    ``(H, W, 3)`` uint8 array is expected to yield an RGB image).
    """
    scaled = np.clip(arr * 255.0, 0, 255)
    # Round to nearest instead of letting astype() truncate toward zero:
    # float32 round trips such as (k / 255) * 255 can land just below k, and
    # truncation would then drop a full level and bias the output darker.
    return Image.fromarray(np.rint(scaled).astype(np.uint8))
|
|
|
|
|
def resize_and_crop_to_grid(img: Image.Image, width: int, height: int, grid: int) -> Image.Image:
    """Resize to ``width`` x ``height`` and trim so each side is a multiple of ``grid``.

    The image is converted to RGB and resized straight to ``(width, height)``
    with Lanczos resampling (the source aspect ratio is not preserved). If a
    side is not a multiple of ``grid``, the surplus pixels are removed with a
    centered crop; otherwise the resized image is returned as is.
    """
    resized = img.convert("RGB").resize((width, height), Image.LANCZOS)

    full_h = resized.height
    full_w = resized.width

    # Largest multiples of `grid` that fit inside each dimension.
    keep_h = full_h - full_h % grid
    keep_w = full_w - full_w % grid

    if keep_h == full_h and keep_w == full_w:
        return resized

    # Split the surplus evenly (floor) between the two opposite edges.
    x0 = (full_w - keep_w) // 2
    y0 = (full_h - keep_h) // 2
    return resized.crop((x0, y0, x0 + keep_w, y0 + keep_h))
|
|
|
|
|
def block_view(arr: np.ndarray, bh: int, bw: int) -> np.ndarray:
    """Expose a 3-D array as non-overlapping ``bh`` x ``bw`` tiles without copying.

    The result has shape ``(rows // bh, cols // bw, bh, bw, channels)`` and is
    built with ``as_strided``, so it aliases the memory of ``arr``.

    Raises:
        ValueError: if the first two dimensions of ``arr`` are not exact
            multiples of ``bh`` and ``bw`` respectively.
    """
    rows, cols, channels = arr.shape
    if rows % bh != 0 or cols % bw != 0:
        raise ValueError("Array dimensions must be divisible by the block size")

    row_step, col_step, chan_step = arr.strides
    # Outer two axes jump a whole tile at a time; inner axes walk inside a tile.
    tiled_shape = (rows // bh, cols // bw, bh, bw, channels)
    tiled_strides = (row_step * bh, col_step * bw, row_step, col_step, chan_step)
    return np.lib.stride_tricks.as_strided(arr, shape=tiled_shape, strides=tiled_strides)
|
|
|
|
|
def cell_means(arr: np.ndarray, grid: int) -> np.ndarray:
    """Return a center-weighted mean colour for every cell of a ``grid`` x ``grid`` tiling.

    The image is split into ``grid`` cells along each of its first two axes.
    Inside a cell, pixel weights fall off linearly with distance from the
    cell center, from 1.0 at the center down to 0.5 at the corners, and are
    normalised to sum to 1.

    Args:
        arr: 3-D array (rows, columns, channels).
        grid: number of cells along each of the first two axes.

    Returns:
        Array of shape ``(grid, grid, channels)`` holding the weighted means.

    Raises:
        ValueError: if ``grid`` is not positive, or if the first two
            dimensions of ``arr`` are not non-zero multiples of ``grid``.
    """
    H, W, _ = arr.shape

    # Validate up front: with a non-multiple size, H // grid can still divide
    # H (e.g. H=10, grid=4 -> 2), which would silently yield the wrong number
    # of cells; grid > H or grid == 0 would surface as a ZeroDivisionError.
    if grid <= 0:
        raise ValueError("grid must be a positive integer")
    if H == 0 or W == 0 or H % grid or W % grid:
        raise ValueError("Array dimensions must be non-zero multiples of grid")

    bh, bw = H // grid, W // grid
    blocks = block_view(arr, bh, bw)

    # Distance of every pixel in a cell from the (possibly half-pixel) center.
    center_h = (bh - 1) / 2.0
    center_w = (bw - 1) / 2.0
    yy, xx = np.meshgrid(np.arange(bh), np.arange(bw), indexing="ij")
    dist = np.sqrt((yy - center_h) ** 2 + (xx - center_w) ** 2)

    # A 1x1 cell has max_dist == 0; fall back to 1.0 to avoid dividing by zero.
    max_dist = np.sqrt(center_h**2 + center_w**2) or 1.0

    weights = (1.0 - (dist / max_dist) * 0.5).astype(np.float32)
    weights /= weights.sum()

    # Broadcast the per-cell weight mask over every cell and channel, then
    # collapse the two within-cell axes.
    weighted = blocks * weights[None, None, :, :, None]
    return weighted.sum(axis=(2, 3))
|
|
|