import math
import numpy as np
import torch
import pyvips

from typing import TypedDict


def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """
    Determine the optimal number of tiles to cover an image with overlapping crops.
    """
    if height <= crop_size or width <= crop_size:
        return (1, 1)

    # Minimum required tiles in each dimension
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)

    # If even the minimum grid exceeds max_crops, shrink both counts
    # proportionally; full coverage is sacrificed to stay within the budget
    if min_h * min_w > max_crops:
        ratio = math.sqrt(max_crops / (min_h * min_w))
        return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))

    # Largest grid that matches the image's aspect ratio within max_crops
    h_tiles = math.floor(math.sqrt(max_crops * height / width))
    w_tiles = math.floor(math.sqrt(max_crops * width / height))

    # Ensure we meet minimum tile requirements
    h_tiles = max(h_tiles, min_h)
    w_tiles = max(w_tiles, min_w)

    # If we exceeded max_crops, scale down the larger dimension
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)

    return (max(1, h_tiles), max(1, w_tiles))
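
# Worked example (illustrative numbers; crop_size=266 matches the usable
# window overlap_crop_image derives below from 378 px crops, 14 px patches,
# and a 4-patch margin): for a 1000 x 1500 image,
# ceil(1000 / 266) * ceil(1500 / 266) = 4 * 6 = 24 tiles exceed max_crops=12,
# so both counts are scaled by sqrt(12 / 24) and floored:
#
#     select_tiling(1000, 1500, crop_size=266, max_crops=12)  # -> (2, 4)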


class OverlapCropOutput(TypedDict):
    crops: np.ndarray
    tiling: tuple[int, int]


def overlap_crop_image(
    image: np.ndarray,
    overlap_margin: int,
    max_crops: int,
    base_size: tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> OverlapCropOutput:
    """
    Process an image using an overlap-and-resize cropping strategy with margin handling.

    This function takes an input image and creates multiple overlapping crops with
    consistent margins. It produces:
    1. A single global crop resized to base_size
    2. Multiple overlapping local crops that maintain high resolution details
    3. A patch ordering matrix that tracks correspondence between crops

    The overlap strategy ensures:
    - Smooth transitions between adjacent crops
    - No loss of information at crop boundaries
    - Proper handling of features that cross crop boundaries
    - Consistent patch indexing across the full image

    Args:
        image (np.ndarray): Input image as numpy array with shape (H,W,C)
        base_size (tuple[int,int]): Target size for crops, default (378,378)
        patch_size (int): Size of patches in pixels, default 14
        overlap_margin (int): Margin size in patch units, default 4
        max_crops (int): Maximum number of crops allowed, default 12

    Returns:
        OverlapCropOutput: Dictionary containing:
            - crops: A numpy array containing the global crop of the full image (index 0)
                followed by the overlapping cropped regions (indices 1+)
            - tiling: Tuple of (height,width) tile counts
    """
    original_h, original_w = image.shape[:2]

    # Convert margin from patch units to pixels
    margin_pixels = patch_size * overlap_margin
    total_margin_pixels = margin_pixels * 2  # Both sides

    # Calculate crop parameters
    crop_patches = base_size[0] // patch_size  # patches per crop dimension
    crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
    crop_window_size = crop_window_patches * patch_size  # usable size in pixels
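    # With the defaults (base_size=378, patch_size=14) and overlap_margin=4:
    # crop_patches = 27, crop_window_patches = 19, crop_window_size = 266 px.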

    # Determine the tiling over the margin-free interior: the overlap margins
    # are shared overhead added once around the whole grid, so subtract them
    # before dividing the image into crop windows
    tiling = select_tiling(
        original_h - total_margin_pixels,
        original_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    # Pre-allocate the output buffer. Index 0 is the global crop; the
    # zero-initialized entries also act as padding for partial edge crops.
    n_crops = tiling[0] * tiling[1] + 1  # +1 for the global crop
    crops = np.zeros(
        (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
    )

    # Resize the image so the tiling covers it exactly; target_size is (H, W)
    target_size = (
        tiling[0] * crop_window_size + total_margin_pixels,
        tiling[1] * crop_window_size + total_margin_pixels,
    )

    # Convert to vips for resizing
    vips_image = pyvips.Image.new_from_array(image)
    scale_x = target_size[1] / image.shape[1]
    scale_y = target_size[0] / image.shape[0]
    resized = vips_image.resize(scale_x, vscale=scale_y)
    image = resized.numpy()

    # Create global crop
    scale_x = base_size[1] / vips_image.width
    scale_y = base_size[0] / vips_image.height
    global_vips = vips_image.resize(scale_x, vscale=scale_y)
    crops[0] = global_vips.numpy()
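    # Note: the global crop is resized from the original-resolution image
    # (vips_image), not from the tiling-sized resize above.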

    for i in range(tiling[0]):
        for j in range(tiling[1]):
            # Calculate crop coordinates
            y0 = i * crop_window_size
            x0 = j * crop_window_size

            # Clip the crop to the resized image bounds; the zero-initialized
            # buffer pads any short right/bottom edge
            y_end = min(y0 + base_size[0], image.shape[0])
            x_end = min(x0 + base_size[1], image.shape[1])

            crop_region = image[y0:y_end, x0:x_end]
            crops[
                1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
            ] = crop_region

    return {"crops": crops, "tiling": tiling}


def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    """
    Reconstruct the original image from overlapping crops into a single seamless image.

    Takes a list of overlapping image crops along with their positional metadata and
    reconstructs them into a single coherent image by carefully stitching together
    non-overlapping regions. Handles both numpy arrays and PyTorch tensors.

    Args:
        crops: List of image crops as numpy arrays or PyTorch tensors with shape
            (H,W,C)
        tiling: Tuple of (height,width) indicating crop grid layout
        patch_size: Size in pixels of each patch, default 14
        overlap_margin: Number of overlapping patches on each edge, default 4

    Returns:
        Reconstructed image as numpy array or PyTorch tensor matching input type,
        with shape (H,W,C) where H,W are the original image dimensions
    """
    tiling_h, tiling_w = tiling
    crop_height, crop_width = crops[0].shape[:2]
    margin_pixels = overlap_margin * patch_size

    # Calculate the output size; the outer margins are added back only once
    output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
    output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
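    # For example (illustrative numbers): 378 px crops, a 4-patch (56 px)
    # margin, and a (2, 4) tiling give (266 * 2 + 112, 266 * 4 + 112)
    # = (644, 1176).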

    reconstructed = torch.zeros(
        (output_h, output_w, crops[0].shape[2]),
        device=crops[0].device,
        dtype=crops[0].dtype,
    )

    for i, crop in enumerate(crops):
        tile_y = i // tiling_w
        tile_x = i % tiling_w

        # For each tile, determine which part to keep
        # Keep left margin only for first column
        x_start = 0 if tile_x == 0 else margin_pixels
        # Keep right margin only for last column
        x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
        # Keep top margin only for first row
        y_start = 0 if tile_y == 0 else margin_pixels
        # Keep bottom margin only for last row
        y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels

        # Calculate where this piece belongs in the output
        out_x = tile_x * (crop_width - 2 * margin_pixels)
        out_y = tile_y * (crop_height - 2 * margin_pixels)

        # Place the piece
        reconstructed[
            out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
        ] = crop[y_start:y_end, x_start:x_end]

    return reconstructed
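

if __name__ == "__main__":
    # Round-trip sketch (illustrative; the 1000 x 1500 input is an assumed
    # example, not part of the pipeline): crop a synthetic image, then stitch
    # the local crops back together. Index 0 of the crop array is the global
    # crop, so it is excluded from stitching.
    demo = np.random.randint(0, 256, (1000, 1500, 3), dtype=np.uint8)
    out = overlap_crop_image(demo, overlap_margin=4, max_crops=12)
    print("tiling:", out["tiling"])  # (2, 4) for this input size
    print("crops:", out["crops"].shape)  # (9, 378, 378, 3)

    local_crops = torch.from_numpy(out["crops"][1:])
    stitched = reconstruct_from_crops(local_crops, out["tiling"], overlap_margin=4)
    # Matches the resized image the crops were cut from: (644, 1176, 3).
    print("reconstructed:", tuple(stitched.shape))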