| """ |
| predictor.py |
| ============ |
| Inference wrapper for WEO-SAS/sen2sr (SEN2SRLite RGBN x4). |
| |
| Super-resolves 4-band Sentinel-2 RGBN imagery from 10 m to 2.5 m (4x). |
| |
| Usage |
| ----- |
| predictor = SEN2SRPredictor("./sen2sr") |
| |
| # Array inference: (4, H, W) float32 in [0, 1] -> (4, H*4, W*4) float32 |
| sr = predictor.predict(image) |
| |
| # GeoTIFF pipeline (reads Sentinel-2 DN, writes SR GeoTIFF at 2.5 m) |
| predictor.predict_tif("s2_scene.tif", "s2_sr.tif", bands=[0, 1, 2, 3]) |
| |
| Requirements |
| ------------ |
| torch, numpy, rasterio, safetensors, sen2sr (pip install sen2sr) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import numpy as np |
| import torch |
| import rasterio |
|
|
|
|
class SEN2SRPredictor:
    """
    SEN2SRLite RGBN x4 predictor.

    Parameters
    ----------
    local_dir : local path to a downloaded WEO-SAS/sen2sr model repo; must
                contain ``config.json`` and, unless ``model`` is given, the
                weight files named in it (``weights_file``,
                ``hard_constraint_file``)
    device    : torch device or device string (e.g. ``"cuda:0"``);
                auto-detected (CUDA if available, else CPU) if None
    model     : pre-built srmodel callable; bypasses weight loading (used by
                sen2sr_pt.py). It is used as-is, i.e. it is NOT moved to
                ``device`` by this class.
    """

    def __init__(
        self,
        local_dir: str,
        device: Optional[torch.device] = None,
        model=None,
    ):
        local_dir = Path(local_dir)
        with open(local_dir / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)

        self.local_dir = local_dir
        self.in_channels = cfg["in_channels"]
        self.out_channels = cfg["out_channels"]
        self.scaling_factor = cfg["scaling_factor"]
        self.patch_size = cfg["patch_size"]
        self.overlap = cfg["overlap"]
        self.p_low = cfg["p_low"]
        self.p_high = cfg["p_high"]
        self.normalization_factor = cfg["normalization_factor"]
        self.description = cfg.get("description", "")

        # Normalise to a torch.device so callers may also pass strings such
        # as "cpu" / "cuda:0" and self.device always has a single type.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        if model is not None:
            self.model = model
        else:
            self._load_model(local_dir, cfg)

    def _load_model(self, local_dir: Path, cfg: dict) -> None:
        """
        Build the SR model from the files in ``local_dir`` and store it in
        ``self.model``.

        Loads the CNNSR weights from ``cfg["weights_file"]`` and the
        ``"weights"`` tensor of ``cfg["hard_constraint_file"]`` (used as the
        HardConstraint low-pass mask), then combines both with
        ``sen2sr.nonreference.srmodel``.

        Raises
        ------
        ImportError : if sen2sr or safetensors is not installed
        """
        try:
            import safetensors.torch
            from sen2sr.models.opensr_baseline.cnn import CNNSR
            from sen2sr.models.tricks import HardConstraint
            from sen2sr.nonreference import srmodel
        except ImportError as exc:
            raise ImportError(
                "sen2sr and safetensors are required. "
                "Install: pip install sen2sr safetensors"
            ) from exc

        device = self.device

        weights = safetensors.torch.load_file(local_dir / cfg["weights_file"])
        sr_model = CNNSR(
            cfg["in_channels"],
            cfg["out_channels"],
            cfg["feature_channels"],
            cfg["scaling_factor"],
            cfg["bias"],
            cfg["train_mode"],
            cfg["num_blocks"],
        )
        sr_model.load_state_dict(weights)
        sr_model.to(device).eval()
        # Inference only: freeze parameters so no autograd state is kept.
        for p in sr_model.parameters():
            p.requires_grad = False

        hc_weights = safetensors.torch.load_file(
            local_dir / cfg["hard_constraint_file"]
        )
        hard_constraint = HardConstraint(
            low_pass_mask=hc_weights["weights"].to(device), device=device
        )

        self.model = srmodel(sr_model, hard_constraint, device)

    def predict(self, image: np.ndarray) -> np.ndarray:
        """
        Run x``scaling_factor`` super-resolution on a (C, H, W) image.

        Images whose height and width both fit within ``patch_size`` are run
        through the model in a single pass. Larger images go through
        ``sen2sr.predict_large`` (with ``overlap``) so that tile boundaries
        are blended seamlessly. sen2sr is only imported for that tiled path.

        Parameters
        ----------
        image : (C, H, W) array, values in [0, 1]; converted to float32.
                C must equal in_channels (4 for RGBN). Any memory layout is
                accepted (including non-contiguous / flipped views).

        Returns
        -------
        (C, H*scaling_factor, W*scaling_factor) float32 array in the same
        radiometric range as the input

        Raises
        ------
        ValueError  : if ``image`` is not (in_channels, H, W)
        ImportError : if tiling is needed and sen2sr is not installed
        """
        if image.ndim != 3 or image.shape[0] != self.in_channels:
            raise ValueError(
                f"Expected ({self.in_channels}, H, W), got {image.shape}"
            )

        # torch.from_numpy rejects arrays with negative strides (e.g. flipped
        # views); a contiguous float32 array is always accepted.
        X = torch.from_numpy(
            np.ascontiguousarray(image, dtype=np.float32)
        ).to(self.device)

        fits_one_patch = (
            image.shape[1] <= self.patch_size
            and image.shape[2] <= self.patch_size
        )

        # no_grad covers BOTH paths: an injected model may have trainable
        # parameters, and an output that requires grad cannot be converted
        # with .numpy().
        with torch.no_grad():
            if fits_one_patch:
                out = self.model(X.unsqueeze(0)).squeeze(0)
            else:
                try:
                    import sen2sr
                except ImportError as exc:
                    raise ImportError("pip install sen2sr") from exc

                out = sen2sr.predict_large(
                    model=self.model,
                    X=X,
                    overlap=self.overlap,
                )

        return out.detach().cpu().numpy().astype(np.float32, copy=False)

    def predict_tif(
        self,
        input_path: str,
        output_path: str,
        bands: Optional[List[int]] = None,
    ) -> None:
        """
        Full GeoTIFF super-resolution pipeline.

        Reads bands from the input raster and normalises Sentinel-2 DN to
        [0, 1]. If the maximum (ignoring NaN) exceeds 2.0, it divides by
        ``normalization_factor`` and clips to [0, 1]; otherwise the values
        are left as-is. It then runs SR and writes a float32, LZW-compressed
        GeoTIFF whose geotransform pixel size is divided by
        ``scaling_factor``. The output is always written with the GTiff
        driver, even when the input uses another format (e.g. JP2, VRT).

        Parameters
        ----------
        input_path  : path to input Sentinel-2 raster
        output_path : output path for the SR GeoTIFF (parent dirs created)
        bands       : 0-based band indices to read
                      (default: range(in_channels))
        """
        bands = bands or list(range(self.in_channels))

        with rasterio.open(input_path) as src:
            # rasterio band indices are 1-based.
            arr = src.read([b + 1 for b in bands]).astype(np.float32)
            profile = src.profile.copy()

        # Heuristic DN detection. nanmax ignores NaN pixels. With plain max,
        # a single NaN would make the comparison False and skip normalisation.
        if np.nanmax(arr) > 2.0:
            arr = np.clip(arr / self.normalization_factor, 0.0, 1.0)

        print(
            f"SR inference model=sen2sr input={arr.shape} "
            f"factor={self.scaling_factor}x {input_path}"
        )

        sr = self.predict(arr)

        print(
            f"Output shape {sr.shape} "
            f"range [{sr.min():.4f}, {sr.max():.4f}]"
        )

        tf = profile["transform"]
        # Affine.scale (classmethod, reached via the instance) shrinks the
        # pixel size while keeping the same upper-left origin.
        new_tf = tf * tf.scale(1.0 / self.scaling_factor, 1.0 / self.scaling_factor)
        out_profile = profile.copy()
        out_profile.update(
            driver="GTiff",
            count=sr.shape[0],
            height=sr.shape[1],
            width=sr.shape[2],
            dtype="float32",
            transform=new_tf,
            compress="lzw",
        )
        # An inherited photometric tag (e.g. RGB) may not match a
        # multi-band float32 output, so drop it.
        out_profile.pop("photometric", None)

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with rasterio.open(output_path, "w", **out_profile) as dst:
            dst.write(sr)

        sr_res = abs(tf.a) / self.scaling_factor
        print(f"Written: {output_path} (res={sr_res:.4f} m)")
|