| | """ |
| | load.py |
| | |
| | Module for loading ensemble models (STAC compatible) and performing |
| | optimized inference on large geospatial imagery using dynamic batching |
| | and Gaussian blending. |
| | """ |
| |
|
| | import math |
| | import pathlib |
| | import itertools |
| | from typing import Literal, Tuple, List |
| |
|
| | import torch |
| | import torch.nn |
| | import numpy as np |
| | import pystac |
| | from torch.utils.data import Dataset, DataLoader |
| | from tqdm import tqdm |
| |
|
| | |
| | |
| | |
| |
|
class EnsembleModel(torch.nn.Module):
    """
    Runtime ensemble that combines the outputs of several sub-models.

    Used when a STAC item ships multiple separate .pt2 files: every
    sub-model is run on the same input and the per-model probability maps
    are reduced according to ``mode``.

    Parameters
    ----------
    *models : torch.nn.Module
        Sub-models to ensemble; each is assumed to map
        (B, C, H, W) -> (B, 1, H, W) — confirm against the exported models.
    mode : str
        Reduction across models: 'min', 'mean', 'median', 'max', or
        'none' (return the raw stacked outputs without reducing).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported reductions.
    """

    _VALID_MODES = ("min", "mean", "median", "max", "none")

    def __init__(self, *models: torch.nn.Module, mode: str = "max") -> None:
        # Validate before touching any state so a bad mode fails fast
        # (the original assigned self.mode first, then raised).
        if mode not in self._VALID_MODES:
            raise ValueError("Mode must be 'none', 'min', 'mean', 'median', or 'max'.")
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        self.mode = mode

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run every sub-model on ``x`` and reduce the results.

        Returns
        -------
        probabilities : torch.Tensor
            (B, 1, H, W) reduced map — or the raw (B, N, H, W) stack when
            mode == 'none'.
        uncertainty : torch.Tensor or None
            (B, 1, H, W) std-dev across models normalized to [0, 1];
            None when mode == 'none' (or the ensemble is empty).
        """
        outputs = [model(x) for model in self.models]
        if not outputs:
            # Empty ensemble: nothing to reduce.
            return None, None

        # (B, N, 1, H, W) -> (B, N, H, W); assumes each output has a
        # singleton channel dim — TODO confirm against the sub-models.
        stacked = torch.stack(outputs, dim=1).squeeze(2)

        if self.mode == "max":
            probs = torch.max(stacked, dim=1, keepdim=True)[0]
        elif self.mode == "mean":
            probs = torch.mean(stacked, dim=1, keepdim=True)
        elif self.mode == "median":
            probs = torch.median(stacked, dim=1, keepdim=True)[0]
        elif self.mode == "min":
            probs = torch.min(stacked, dim=1, keepdim=True)[0]
        else:  # mode == "none" (only remaining value after __init__ check)
            return stacked, None

        n_models = len(outputs)
        if n_models > 1:
            std = torch.std(stacked, dim=1, keepdim=True)
            # Largest possible unbiased std for values bounded in [0, 1]
            # (half the models at 0, half at 1): sqrt(0.25 * N / (N - 1)).
            std_max = math.sqrt(0.25 * n_models / (n_models - 1))
            uncertainty = torch.clamp(std / std_max, 0.0, 1.0)
        else:
            # A single model carries no ensemble disagreement.
            uncertainty = torch.zeros_like(probs)

        return probs, uncertainty
| |
|
def get_spline_window(window_size: int, power: int = 2) -> np.ndarray:
    """Build a 2-D tapering window used to blend overlapping tile edges.

    A 1-D Hann window is expanded to 2-D via an outer product and raised
    to ``power`` to sharpen the taper toward the tile borders.
    """
    hann_1d = np.hanning(window_size)
    hann_2d = np.outer(hann_1d, hann_1d) ** power
    return hann_2d.astype(np.float32)
| |
|
def fix_lastchunk(iterchunks, s2dim, chunk_size):
    """Clamp chunk origins so every chunk fits inside the image.

    Offsets whose chunk would overrun the image are shifted back so the
    chunk ends exactly at the border. Clamping can make several offsets
    collapse onto the same coordinate, so duplicates are removed.

    Parameters
    ----------
    iterchunks : iterable of (int, int)
        Candidate (row, col) top-left offsets.
    s2dim : tuple
        Image dimensions as (height, width).
    chunk_size : int
        Side length of the square chunks.

    Returns
    -------
    list of (int, int)
        Unique clamped offsets in first-seen order. (The original used
        ``list(set(...))``, which does not guarantee a stable order.)
    """
    clamped = []
    for row, col in iterchunks:
        if row + chunk_size > s2dim[0]:
            row = max(s2dim[0] - chunk_size, 0)
        if col + chunk_size > s2dim[1]:
            col = max(s2dim[1] - chunk_size, 0)
        clamped.append((row, col))
    # dict.fromkeys dedupes while preserving insertion order.
    return list(dict.fromkeys(clamped))


def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0):
    """Generate top-left (row, col) offsets for sliding-window inference.

    Parameters
    ----------
    dimension : tuple
        Image dimensions as (height, width).
    chunk_size : int
        Side length of each square window.
    overlap : int
        Pixels of overlap between neighbouring windows; must be smaller
        than ``chunk_size``.

    Returns
    -------
    list of (int, int)
        Unique window origins covering the whole image.

    Raises
    ------
    ValueError
        If ``overlap`` is negative or >= ``chunk_size`` — the stride
        would be zero or negative, which previously surfaced as an
        obscure ``range`` error.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    dimy, dimx = dimension
    # A window at least as large as the image collapses to a single origin.
    if chunk_size > max(dimx, dimy):
        return [(0, 0)]

    step = chunk_size - overlap
    iterchunks = list(itertools.product(
        range(0, dimy, step),
        range(0, dimx, step),
    ))

    # Clamp offsets so the last windows end exactly at the image border.
    return fix_lastchunk(iterchunks, dimension, chunk_size)
| |
|
| | |
| | |
| | |
| |
|
class PatchDataset(Dataset):
    """
    Torch dataset that slices fixed-size patches out of a (C, H, W) image.

    Slicing, border padding, and nodata-mask computation all happen here
    so they run on CPU dataloader workers rather than the main process.
    """

    def __init__(self, image: np.ndarray, coords: List[Tuple[int, int]], chunk_size: int, nodata: float = 0):
        self.image = image          # full image, (C, H, W)
        self.coords = coords        # (row, col) top-left offsets
        self.chunk_size = chunk_size
        self.nodata = nodata        # fill value for padding / invalid pixels

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        row_off, col_off = self.coords[idx]
        row_end = row_off + self.chunk_size
        col_end = col_off + self.chunk_size

        # The slice may come up short of chunk_size at the image border.
        window = self.image[:, row_off:row_end, col_off:col_end]
        _, valid_h, valid_w = window.shape

        tensor = torch.from_numpy(window).float()

        # Right/bottom-pad border patches up to the full chunk size.
        missing_h = self.chunk_size - valid_h
        missing_w = self.chunk_size - valid_w
        if missing_h > 0 or missing_w > 0:
            tensor = torch.nn.functional.pad(tensor, (0, missing_w, 0, missing_h), "constant", self.nodata)

        # True where every band equals the nodata value (includes padding).
        nodata_mask = (tensor == self.nodata).all(dim=0)

        return tensor, row_off, col_off, valid_h, valid_w, nodata_mask
| |
|
| | |
| | |
| | |
| |
|
def compiled_model(
    path: pathlib.Path,
    stac_item: pystac.Item,
    mode: Literal["min", "mean", "median", "max"] = "max",
    *args, **kwargs
):
    """
    Load the .pt2 model asset(s) referenced by a STAC item.

    Returns the bare model when the item carries a single .pt2 asset,
    otherwise wraps all of them in an ``EnsembleModel`` reduced with
    ``mode``. ExportedProgram objects are unwrapped to plain modules
    when possible.

    Raises
    ------
    ValueError
        If the STAC item references no .pt2 assets.
    """
    hrefs = sorted(
        asset.href
        for asset in stac_item.assets.values()
        if asset.href.endswith(".pt2")
    )
    if not hrefs:
        raise ValueError("No .pt2 files found in STAC item assets.")

    def _load(href):
        # torch.export.load yields an ExportedProgram; .module() exposes
        # the callable module inside it when available.
        program = torch.export.load(href)
        return program.module() if hasattr(program, "module") else program

    if len(hrefs) == 1:
        return _load(hrefs[0])
    return EnsembleModel(*[_load(h) for h in hrefs], mode=mode)
| |
|
| |
|
def predict_large(
    image: np.ndarray,
    model: torch.nn.Module,
    chunk_size: int = 512,
    overlap: int = 128,
    batch_size: int = 16,
    num_workers: int = 8,
    device: str = "cuda",
    nodata: float = 0.0
) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
    """
    Tiled inference over a large (C, H, W) image with weighted blending.

    The image is split into overlapping ``chunk_size`` windows, batched
    through ``model``, and the per-window outputs are blended with a
    squared-Hann weight so seams between windows are smooth. Pixels that
    are ``nodata`` across all bands get zero blending weight.

    Parameters
    ----------
    image : np.ndarray
        Input image, (C, H, W).
    model : torch.nn.Module
        Model mapping (B, C, chunk, chunk) -> (B, 1, chunk, chunk), or an
        ensemble returning a (probabilities, uncertainty) tuple.
    chunk_size, overlap : int
        Sliding-window geometry.
    batch_size, num_workers : int
        DataLoader settings; ``num_workers=0`` runs in-process.
    device : str
        Torch device string, e.g. "cuda" or "cpu".
    nodata : float
        Value marking invalid pixels in the input and the output.

    Returns
    -------
    np.ndarray or (np.ndarray, np.ndarray)
        (1, H, W) blended probabilities, plus a (1, H, W) uncertainty map
        when the model returns a 2-tuple (ensemble).

    Raises
    ------
    ValueError
        If ``image`` is not 3-dimensional.
    """
    if image.ndim != 3:
        raise ValueError(f"Input image must be (C, H, W). Received {image.shape}")

    # Unwrap a torch.export ExportedProgram into its nn.Module if possible.
    if hasattr(model, "module") and callable(model.module):
        try:
            unpacked = model.module()
            if isinstance(unpacked, torch.nn.Module):
                model = unpacked
        except Exception:
            pass

    # Best-effort freeze for inference. Some exported/compiled artifacts do
    # not expose eval()/parameters(); catch Exception only — the original
    # bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
    try:
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
    except Exception:
        pass

    if isinstance(model, torch.nn.Module):
        model = model.to(device)

    bands, height, width = image.shape

    # Probe with a dummy batch: a 2-tuple output marks an ensemble model.
    dummy = torch.randn(2, bands, chunk_size, chunk_size).to(device)
    with torch.no_grad():
        out = model(dummy)
    is_ensemble = isinstance(out, tuple) and len(out) == 2

    # Weighted accumulators (numerator and weight denominator).
    out_probs = np.zeros((1, height, width), dtype=np.float32)
    count_map = np.zeros((1, height, width), dtype=np.float32)
    out_uncert = np.zeros((1, height, width), dtype=np.float32) if is_ensemble else None

    # Squared-Hann blending window, shared across all patches.
    window_spline = get_spline_window(chunk_size, power=2)
    window_tensor = torch.from_numpy(window_spline).to(device)

    coords = define_iteration((height, width), chunk_size, overlap)
    dataset = PatchDataset(image, coords, chunk_size, nodata)

    # prefetch_factor is only legal with worker processes — passing it with
    # num_workers=0 makes DataLoader raise (this previously broke CPU/
    # in-process runs). pin_memory only helps when transferring to CUDA.
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": num_workers,
        "pin_memory": "cuda" in str(device),
    }
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2
    loader = DataLoader(dataset, **loader_kwargs)

    for batch in tqdm(loader, desc=f"Inference (Batch {batch_size})"):
        patches, r_offs, c_offs, h_actuals, w_actuals, nodata_masks = batch

        patches = patches.to(device, non_blocking=True)
        nodata_masks = nodata_masks.to(device, non_blocking=True)

        with torch.no_grad():
            if is_ensemble:
                probs, uncert = model(patches)
            else:
                probs = model(patches)
                uncert = None

        # Normalize outputs to (B, 1, H, W).
        if probs.ndim == 3:
            probs = probs.unsqueeze(1)
        if is_ensemble and uncert.ndim == 3:
            uncert = uncert.unsqueeze(1)

        # Per-patch blending weights; all-nodata pixels get zero weight so
        # they contribute nothing to the blend.
        B = patches.size(0)
        batch_weights = window_tensor.unsqueeze(0).unsqueeze(0).repeat(B, 1, 1, 1)
        batch_weights[nodata_masks.unsqueeze(1)] = 0.0

        probs_weighted = probs * batch_weights
        if is_ensemble:
            uncert_weighted = uncert * batch_weights

        probs_cpu = probs_weighted.cpu().numpy()
        weights_cpu = batch_weights.cpu().numpy()
        if is_ensemble:
            uncert_cpu = uncert_weighted.cpu().numpy()

        # Scatter each patch into the global accumulators, cropping any
        # border padding back to the patch's valid (h, w) extent.
        for i in range(B):
            r, c = r_offs[i].item(), c_offs[i].item()
            h, w = h_actuals[i].item(), w_actuals[i].item()

            out_probs[:, r:r+h, c:c+w] += probs_cpu[i, :, :h, :w]
            count_map[:, r:r+h, c:c+w] += weights_cpu[i, :, :h, :w]

            if is_ensemble:
                out_uncert[:, r:r+h, c:c+w] += uncert_cpu[i, :, :h, :w]

    # Normalize by total weight; pixels never touched by a nonzero weight
    # become nodata (avoid division by zero first).
    mask_zero = (count_map == 0)
    count_map[mask_zero] = 1.0

    out_probs /= count_map
    out_probs[mask_zero] = nodata

    if is_ensemble:
        out_uncert /= count_map
        out_uncert[mask_zero] = nodata
        return out_probs, out_uncert

    return out_probs