| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict, List, Optional, Sequence |
|
|
| import torch |
| from PIL import Image |
|
|
| from .config import ProtoMorphConfig |
| from .model import ProtoMorphDINOv3 |
|
|
|
|
def load_image(path: str | Path) -> Image.Image:
    """Open the image at ``path`` and return it converted to RGB mode.

    The source file is opened inside a ``with`` block so that its handle is
    released deterministically once the converted image has been produced,
    instead of whenever the intermediate ``Image`` object is garbage-collected.

    Args:
        path: Location of the image file to read.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    with Image.open(path) as source:
        # convert() returns a separate image object, so the result remains
        # usable after the context manager has closed the source file.
        return source.convert("RGB")
|
|
|
|
def load_labels(path: Optional[str | Path], num_classes: int) -> List[str]:
    """Load class labels and normalise them to exactly ``num_classes`` entries.

    Supported inputs:

    * ``None`` - every label is the placeholder ``"class_<i>"``.
    * A ``.json`` file holding an object - labels are looked up by the class
      index rendered as a string (``"0"``, ``"1"``, ...); missing indices get
      the placeholder.
    * A ``.json`` file holding any other iterable (typically an array) -
      entries are used in order.
    * Any other file - one label per non-blank line, surrounding whitespace
      stripped.

    Sequence-style inputs (JSON array or text file) are padded with
    placeholders when too short and truncated when too long, so the result
    can always be indexed by any class id below ``num_classes``.

    Args:
        path: Label file location, or ``None`` to generate placeholders.
        num_classes: Number of labels to return.

    Returns:
        A list of ``num_classes`` labels.
    """
    if path is None:
        return [f"class_{i}" for i in range(num_classes)]

    p = Path(path)
    # Read as UTF-8 explicitly so non-ASCII labels do not depend on the
    # platform's locale encoding.
    text = p.read_text(encoding="utf-8")

    if p.suffix.lower() == ".json":
        data = json.loads(text)
        if isinstance(data, dict):
            # JSON object keys are always strings, so only the string form of
            # the index can ever match.
            return [data.get(str(i), f"class_{i}") for i in range(num_classes)]
        labels = list(data)
    else:
        labels = [line.strip() for line in text.splitlines() if line.strip()]

    # Pad / truncate uniformly for both the JSON-array and plain-text forms.
    if len(labels) < num_classes:
        labels += [f"class_{i}" for i in range(len(labels), num_classes)]
    return labels[:num_classes]
|
|
|
|
def build_model(
    config_path: str | Path,
    checkpoint_path: Optional[str | Path],
    device: str = "cuda",
    local_files_only: bool = False,
    allow_random_head: bool = False,
) -> ProtoMorphDINOv3:
    """Construct a ``ProtoMorphDINOv3`` in eval mode and load its custom head.

    Args:
        config_path: JSON file passed to ``ProtoMorphConfig.from_json``.
        checkpoint_path: Custom-head checkpoint handed to
            ``model.load_custom_head`` when the file exists. May be ``None``.
        device: Requested torch device string. If it names a CUDA or MPS
            device whose backend is not available, the model is placed on the
            CPU instead. Other device types are used as requested.
        local_files_only: Forwarded unchanged to the model constructor.
        allow_random_head: When ``True``, a missing checkpoint is tolerated
            and the head keeps whatever weights the constructor gave it.

    Returns:
        The model, moved to the resolved device and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint is ``None`` or does not exist and
            ``allow_random_head`` is ``False``.
    """
    cfg = ProtoMorphConfig.from_json(config_path)

    # Fall back to CPU only when the *requested* backend is unavailable.
    # Keying the decision on CUDA availability alone would silently downgrade
    # e.g. an "mps" request to CPU on machines without CUDA.
    requested = torch.device(device)
    if requested.type == "cuda":
        usable = torch.cuda.is_available()
    elif requested.type == "mps":
        mps_backend = getattr(torch.backends, "mps", None)
        usable = mps_backend is not None and mps_backend.is_available()
    else:
        usable = True
    device_obj = requested if usable else torch.device("cpu")

    model = ProtoMorphDINOv3(cfg, local_files_only=local_files_only).to(device_obj).eval()

    if checkpoint_path is not None and Path(checkpoint_path).exists():
        model.load_custom_head(checkpoint_path)
    elif not allow_random_head:
        raise FileNotFoundError(
            f"Missing custom-head checkpoint: {checkpoint_path}. "
            "Pass --allow-random-head only for smoke tests; random logits are not meaningful."
        )
    return model
|
|
|
|
def _ranked_entries(indices, values, labels: Sequence[str]) -> List[Dict]:
    """Turn parallel top-k class ids and probabilities into ranked result dicts."""
    entries: List[Dict] = []
    for rank, (idx, val) in enumerate(zip(indices.tolist(), values.tolist()), start=1):
        class_id = int(idx)
        # Fall back to a placeholder name instead of raising IndexError when
        # the label list is shorter than the model's class count.
        label = labels[class_id] if class_id < len(labels) else f"class_{class_id}"
        entries.append({"rank": rank, "class_id": class_id, "label": label, "prob": float(val)})
    return entries


@torch.no_grad()
def predict_paths(
    model: ProtoMorphDINOv3,
    image_paths: Sequence[str | Path],
    labels: List[str],
    topk: int = 5,
    device: str = "cuda",
    force_hard: bool = False,
) -> List[Dict]:
    """Run the model on a batch of image files and summarise the predictions.

    All images are loaded with ``load_image`` and passed to ``model`` in a
    single call. The model output is read as a mapping with the keys
    ``logits``, ``main_logits``, ``hard_mask``, ``gate_pmax``, ``gate_margin``,
    ``gate_entropy``, ``patch_hw`` and ``pixel_hw``.

    Args:
        model: Callable invoked as ``model(images, device=..., force_hard=...)``.
        image_paths: Image files to classify. An empty sequence yields ``[]``
            without invoking the model.
        labels: Class names indexed by class id; ids beyond its length are
            reported as ``"class_<id>"``.
        topk: Maximum number of classes reported per image (capped at the
            size of the last logits dimension).
        device: Forwarded unchanged to the model call.
        force_hard: Forwarded unchanged to the model call.

    Returns:
        One dict per input path, in input order, with the keys ``image``,
        ``hard_case``, ``gate`` (``pmax`` / ``margin`` / ``entropy``),
        ``topk`` (from ``logits``), ``main_topk`` (from ``main_logits``),
        ``patch_hw`` and ``pixel_hw``. The last two are copied as-is from the
        model output and are the same objects for every result of the batch.
    """
    if len(image_paths) == 0:
        return []

    images = [load_image(p) for p in image_paths]
    out = model(images, device=device, force_hard=force_hard)

    # Softmax over the last dimension, then move everything to CPU once so the
    # per-image loop below works on plain tensors / Python lists.
    probs = out["logits"].softmax(dim=-1).float().cpu()
    main_probs = out["main_logits"].softmax(dim=-1).float().cpu()
    hard_mask = out["hard_mask"].cpu().tolist()
    gate_pmax = out["gate_pmax"].float().cpu().tolist()
    gate_margin = out["gate_margin"].float().cpu().tolist()
    gate_entropy = out["gate_entropy"].float().cpu().tolist()

    # The cap depends only on the batch-wide class dimension, so compute it once.
    k = min(topk, probs.shape[-1])

    results: List[Dict] = []
    for i, path in enumerate(image_paths):
        values, indices = probs[i].topk(k)
        main_values, main_indices = main_probs[i].topk(k)
        results.append(
            {
                "image": str(path),
                "hard_case": bool(hard_mask[i]),
                "gate": {
                    "pmax": float(gate_pmax[i]),
                    "margin": float(gate_margin[i]),
                    "entropy": float(gate_entropy[i]),
                },
                "topk": _ranked_entries(indices, values, labels),
                "main_topk": _ranked_entries(main_indices, main_values, labels),
                "patch_hw": out["patch_hw"],
                "pixel_hw": out["pixel_hw"],
            }
        )
    return results
|
|