# HWI-ADS-v1 / inference.py
# Uploaded by Addax-Data-Science ("Upload 2 files", revision 4271575, verified).
"""
ModelInference for the addax-sppnet model family.
Architecture: SpeciesNet GraphModule backbone (frozen) + a thin nn.Linear
head fine-tuned per region. Originally written for AddaxAI's legacy
classify_detections.py (Peter van Lunteren, 13 May 2025); ported here to
the WebUI's class-based ModelInference interface.
Files expected in the model directory:
- <model_fname>.pt fine-tuned head checkpoint, e.g. final-20260317.pt
- <backbone>.pt frozen SpeciesNet backbone, one of:
- always_crop_99710272_22x8_v12_epoch_00148.pt
- full_image_88545560_22x8_v12_epoch_00153.pt
"""
from __future__ import annotations
# Allow loading checkpoints saved on a Windows runner on a POSIX machine.
import pathlib
import platform
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
# Checkpoints pickled on a Windows runner embed WindowsPath objects; aliasing
# the class lets them unpickle on POSIX hosts (no-op on Windows itself).
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath  # type: ignore[assignment]

# Don't fail on truncated images during inference.
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

# Known frozen SpeciesNet backbone files; exactly one is expected to sit next
# to the fine-tuned head checkpoint in the model directory.
_BACKBONE_FILENAMES = (
    "always_crop_99710272_22x8_v12_epoch_00148.pt",
    "full_image_88545560_22x8_v12_epoch_00153.pt",
)
def _load_fx_checkpoint(weights_path: Path, map_location: str = "cpu") -> nn.Module:
"""Load a SpeciesNet onnx2torch GraphModule.
The backbone is shipped as a torch.fx GraphModule. PyTorch 2.4+
requires `reduce_graph_module` to be in the safe-globals allowlist
when loading with `weights_only=True`; older versions don't have
this concept. Try both paths.
"""
try:
from torch.fx.graph_module import reduce_graph_module
from torch.serialization import add_safe_globals
add_safe_globals([reduce_graph_module])
except Exception:
pass
try:
obj = torch.load(weights_path, map_location=map_location, weights_only=True)
except Exception:
obj = torch.load(weights_path, map_location=map_location, weights_only=False)
if hasattr(obj, "state_dict") and hasattr(obj, "forward"):
return obj
raise ValueError(f"{weights_path} is not a torch.nn.Module GraphModule")
class _FXClassifier(nn.Module):
"""SpeciesNet backbone (frozen) + linear head."""
def __init__(
self,
backbone: nn.Module,
num_classes: int,
img_size: int = 480,
input_layout: str = "nhwc",
) -> None:
super().__init__()
self.backbone = backbone
self.input_layout = input_layout.lower()
for p in self.backbone.parameters():
p.requires_grad = False
self.backbone.eval()
# Probe the backbone to discover output feature size at this
# img_size + layout combo, so the head matches exactly.
with torch.no_grad():
x = torch.zeros(1, 3, img_size, img_size)
if self.input_layout == "nhwc":
x = x.permute(0, 2, 3, 1).contiguous()
z = self.backbone(x)
z = self._pool(z)
in_features = z.shape[1]
self.head = nn.Linear(in_features, num_classes)
@staticmethod
def _pool(z: torch.Tensor) -> torch.Tensor:
if z.ndim == 4:
return F.adaptive_avg_pool2d(z, 1).flatten(1)
if z.ndim == 3:
return z.mean(dim=1)
return z.flatten(1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.input_layout == "nhwc":
x = x.permute(0, 2, 3, 1).contiguous()
z = self.backbone(x)
z = self._pool(z)
return self.head(z)
class ModelInference:
    """ModelInference for the addax-sppnet family (SpeciesNet backbone + linear head)."""

    def __init__(self, model_dir: Path, model_path: Path) -> None:
        self.model_dir = Path(model_dir)
        self.model_path = Path(model_path)
        # Populated by load_model().
        self.model: _FXClassifier | None = None
        self.device: torch.device | None = None
        self._class_names: list[str] = []
        self._preprocess: transforms.Compose | None = None

    # ------------------------------------------------------------------
    # Required interface
    # ------------------------------------------------------------------
    def check_gpu(self) -> bool:
        """Return True when an MPS or CUDA accelerator is available."""
        try:
            mps_ok = torch.backends.mps.is_built() and torch.backends.mps.is_available()
        except Exception:
            # Older torch builds may lack the mps backend entirely.
            mps_ok = False
        return mps_ok or torch.cuda.is_available()

    def load_model(self) -> None:
        """Load head + backbone checkpoints and build the preprocessing pipeline."""
        if not self.check_gpu():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device(
                "mps" if torch.backends.mps.is_available() else "cuda"
            )
        # Load fine-tuned head checkpoint; prefer the safe loader, fall back
        # to full unpickling for checkpoints the safe loader rejects.
        try:
            checkpoint = torch.load(
                self.model_path, map_location=self.device, weights_only=True
            )
        except Exception:
            checkpoint = torch.load(
                self.model_path, map_location=self.device, weights_only=False
            )
        # Resolve backbone path: the fine-tuned model ships alongside one of
        # two known backbone files, depending on the recipe.
        backbone_path = next(
            (
                candidate
                for candidate in (self.model_dir / name for name in _BACKBONE_FILENAMES)
                if candidate.exists()
            ),
            None,
        )
        if backbone_path is None:
            raise FileNotFoundError(
                "Backbone weights not found. Expected one of "
                f"{_BACKBONE_FILENAMES} in {self.model_dir}."
            )
        backbone = _load_fx_checkpoint(backbone_path, map_location="cpu")
        classifier = _FXClassifier(
            backbone=backbone,
            num_classes=checkpoint["num_classes"],
            img_size=checkpoint["img_size"],
            input_layout=checkpoint["input_layout"],
        )
        classifier.load_state_dict(checkpoint["model"])
        self.model = classifier.to(self.device).eval()
        self._class_names = list(checkpoint["class_names"])
        # Preprocessing must mirror training: resize, tensorize, normalize.
        norm = checkpoint["normalize"]
        side = checkpoint["img_size"]
        self._preprocess = transforms.Compose([
            transforms.Resize((side, side), antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm["mean"], std=norm["std"]),
        ])

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """Crop the bbox region. SpeciesNet head was trained on tight crops."""
        width, height = image.size
        x, y, w, h = bbox
        # bbox is normalized (x, y, w, h); convert to clamped pixel corners.
        box = (
            max(0, int(round(x * width))),
            max(0, int(round(y * height))),
            min(width, int(round((x + w) * width))),
            min(height, int(round((y + h) * height))),
        )
        left, top, right, bottom = box
        if right <= left or bottom <= top:
            # Degenerate box after clamping: fall back to the full frame.
            return image
        return image.crop(box)

    def get_classification(self, crop: Image.Image) -> list[list]:
        """Per-image inference. Returns [[name, prob], ...] for all classes."""
        assert self.model is not None and self._preprocess is not None
        if crop.mode != "RGB":
            crop = crop.convert("RGB")
        batch = self._preprocess(crop).unsqueeze(0).to(self.device)
        with torch.no_grad():
            scores = F.softmax(self.model(batch), dim=1).cpu().numpy()[0]
        return [[self._class_names[i], float(p)] for i, p in enumerate(scores)]

    def get_class_names(self) -> dict[str, str]:
        """1-indexed mapping {id: class_name} for the output JSON."""
        return {str(idx): name for idx, name in enumerate(self._class_names, start=1)}

    # ------------------------------------------------------------------
    # Optional batch interface (5-15x GPU speedup vs per-crop calls)
    # ------------------------------------------------------------------
    def get_tensor(self, crop: Image.Image) -> np.ndarray:
        """Preprocess one crop to a CHW float array for classify_batch."""
        assert self._preprocess is not None
        rgb = crop if crop.mode == "RGB" else crop.convert("RGB")
        return self._preprocess(rgb).numpy()

    def classify_batch(self, batch: np.ndarray) -> list[list[list]]:
        """Run one forward pass over a stacked batch of preprocessed crops."""
        assert self.model is not None
        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            probs = F.softmax(self.model(tensor), dim=1).cpu().numpy()
        results: list[list[list]] = []
        for row in probs:
            results.append([[self._class_names[i], float(p)] for i, p in enumerate(row)])
        return results