Spaces:
Running
Running
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from .models.cmx.builder_np_conf import myEncoderDecoder as TruForNetwork | |
# Module-level logger, keyed on this module's import name (``__name__``).
LOGGER = logging.getLogger(__name__)
@dataclass
class TruForOutputs:
    """Lightweight container for TruFor inference outputs.

    The class is a dataclass so it can be built with keyword arguments
    (``TruForOutputs(tamper_map=..., confidence_map=..., detection_score=...)``).

    Attributes:
        tamper_map: Softmax probability taken from channel 1 of the network's
            first output (presumably the "tampered" class -- confirm against
            the vendored model).
        confidence_map: Sigmoid of the network's confidence output, or
            ``None`` when the network does not return one.
        detection_score: Sigmoid of the network's detection output as a plain
            ``float``, or ``None`` when the network does not return one.
    """

    tamper_map: np.ndarray
    confidence_map: Optional[np.ndarray]
    detection_score: Optional[float]
class TruForBundledModel:
    """Loads the TruFor network from the vendored sources and runs inference.

    The constructor builds the network from a fixed configuration, loads the
    checkpoint found at ``weights_path``, moves the model to ``device`` and
    switches it to evaluation mode. ``predict`` then runs a single image
    through the network.
    """

    class AttrNamespace(dict):
        """A ``dict`` whose keys can also be read and written as attributes.

        Used to hand the vendored network a config object that supports both
        ``cfg.MODEL`` and ``cfg["MODEL"]`` style access. Membership tests
        (``key in ns``) use the inherited ``dict`` behaviour.
        """

        def __getattr__(self, item):
            # Only reached when normal attribute lookup fails; translate a
            # missing key into AttributeError so hasattr()/getattr() behave.
            try:
                return self[item]
            except KeyError as exc:
                raise AttributeError(item) from exc

        def __setattr__(self, key, value):
            # Attribute assignment is stored as a dictionary entry.
            self[key] = value

    def __init__(self, weights_path: Path | str, device: str = "cpu") -> None:
        """Build the network and load its weights.

        Args:
            weights_path: Location of the TruFor checkpoint file.
            device: Torch device string the model is moved to.

        Raises:
            FileNotFoundError: If ``weights_path`` does not exist.
            ValueError: If ``device`` is rejected by ``torch.device``.
        """
        self.weights_path = Path(weights_path)
        if not self.weights_path.exists():
            raise FileNotFoundError(f"TruFor weights missing at {self.weights_path}")

        try:
            self.device = torch.device(device)
        except RuntimeError as exc:  # pragma: no cover - defensive path for invalid strings
            raise ValueError(f"Unsupported torch device '{device}'") from exc

        self.model = self._build_model().to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> TruForOutputs:
        """Run TruFor on a single image.

        Args:
            image: Input image; it is converted to RGB before inference.

        Returns:
            A ``TruForOutputs`` holding the tamper map and, when the network
            provides them, the confidence map and the detection score.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("An input image is required for TruFor inference.")

        tensor = self._prepare_tensor(image).to(self.device)

        with torch.inference_mode():
            # The network returns four values; the last one is not used here.
            pred, conf, det, _ = self.model(tensor)

            # First batch element, softmax over its leading dimension, keep
            # channel 1 (presumably the "tampered" class -- confirm).
            tamper_map = torch.softmax(pred[0], dim=0)[1].cpu().numpy()

            confidence_map: Optional[np.ndarray] = None
            if conf is not None:
                confidence_map = torch.sigmoid(conf[0][0]).cpu().numpy()

            detection_score: Optional[float] = None
            if det is not None:
                # .item() requires a single-element tensor, i.e. a batch of one.
                detection_score = torch.sigmoid(det).item()

        return TruForOutputs(
            tamper_map=tamper_map,
            confidence_map=confidence_map,
            detection_score=detection_score,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_model(self) -> torch.nn.Module:
        """Instantiate the network and load the checkpoint into it.

        Weights are loaded non-strictly; missing or unexpected keys are
        logged as warnings instead of raising.
        """
        cfg = self._default_config()
        model = TruForNetwork(cfg=cfg)

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from a
        # trusted source.
        checkpoint = torch.load(self.weights_path, map_location="cpu", weights_only=False)

        # Accept both wrapped ({"state_dict": ...}) and bare state dicts.
        state_dict = checkpoint.get("state_dict", checkpoint)

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            LOGGER.warning("TruFor missing keys: %s", sorted(missing))
        if unexpected:
            LOGGER.warning("TruFor unexpected keys: %s", sorted(unexpected))
        return model

    @staticmethod
    def _prepare_tensor(image: Image.Image) -> torch.Tensor:
        """Convert an image to a ``float32`` tensor of shape ``(1, 3, H, W)``.

        The image is converted to RGB, moved to channel-first layout, given a
        leading batch dimension and scaled by ``1 / 256``.
        """
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
        tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
        tensor = tensor / 256.0  # matches the reference implementation
        return tensor

    @classmethod
    def _default_config(cls) -> TruForBundledModel.AttrNamespace:
        """Return the fixed configuration passed to the vendored network."""
        extra = cls.AttrNamespace(
            BACKBONE="mit_b2",
            DECODER="MLPDecoder",
            DECODER_EMBED_DIM=512,
            PREPRC="imagenet",
            BN_EPS=0.001,
            BN_MOMENTUM=0.1,
            DETECTION="confpool",
            CONF=True,
            NP_WEIGHTS="",
        )
        model = cls.AttrNamespace(
            NAME="detconfcmx",
            MODS=("RGB", "NP++"),
            PRETRAINED="",
            EXTRA=extra,
        )
        dataset = cls.AttrNamespace(NUM_CLASSES=2)
        return cls.AttrNamespace(MODEL=model, DATASET=dataset)