Jatin-tec
Add application file
65d7391
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from PIL import Image
from .models.cmx.builder_np_conf import myEncoderDecoder as TruForNetwork
# Module-level logger; TruForBundledModel._build_model uses it to report
# missing/unexpected checkpoint keys as warnings.
LOGGER = logging.getLogger(__name__)
@dataclass(eq=True, frozen=True)
class TruForOutputs:
    """Immutable bundle of the results returned by ``TruForBundledModel.predict``.

    Instances are frozen: assigning to a field after construction raises
    ``dataclasses.FrozenInstanceError``.
    """

    # Softmax probability of class index 1 for every pixel of the input image.
    tamper_map: np.ndarray
    # Sigmoid of the network's confidence output, or None if it produced none.
    confidence_map: Optional[np.ndarray]
    # Sigmoid of the network's detection output as a Python float, or None.
    detection_score: Optional[float]
class TruForBundledModel:
    """Load the vendored TruFor network from a checkpoint and run single-image inference.

    The network is built from :meth:`_default_config`, its weights are loaded
    from ``weights_path`` (non-strictly), and it is moved to ``device`` and put
    in evaluation mode.
    """

    def __init__(self, weights_path: Path | str, device: str = "cpu") -> None:
        """Build the network and load its weights.

        Args:
            weights_path: Path to the TruFor checkpoint file.
            device: Any string accepted by ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda:0"``).

        Raises:
            FileNotFoundError: If ``weights_path`` is not an existing regular file.
            ValueError: If ``device`` is not a valid torch device string.
        """
        self.weights_path = Path(weights_path)
        # ``is_file`` rather than ``exists``: a directory would otherwise pass
        # this check and only fail later, less clearly, inside ``torch.load``.
        if not self.weights_path.is_file():
            raise FileNotFoundError(f"TruFor weights missing at {self.weights_path}")
        try:
            self.device = torch.device(device)
        except RuntimeError as exc:  # pragma: no cover - defensive path for invalid strings
            raise ValueError(f"Unsupported torch device '{device}'") from exc
        self.model = self._build_model().to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> TruForOutputs:
        """Run TruFor on a single image.

        Args:
            image: Input image; it is converted to RGB before inference.

        Returns:
            A ``TruForOutputs`` holding the softmax probability of class 1 for
            the first batch element, the sigmoid confidence map, and the sigmoid
            detection score. The last two are ``None`` when the network returns
            ``None`` for them.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("An input image is required for TruFor inference.")
        tensor = self._prepare_tensor(image).to(self.device)
        # Keep all post-processing inside inference mode as well, so no tensor
        # op here records autograd state.
        with torch.inference_mode():
            # The fourth network output is not used by this wrapper.
            pred, conf, det, _ = self.model(tensor)
            # Channel 1 of the 2-class softmax (NUM_CLASSES=2) is reported as
            # the tamper probability.
            tamper_map = torch.softmax(pred[0], dim=0)[1].cpu().numpy()
            confidence_map: Optional[np.ndarray] = None
            if conf is not None:
                confidence_map = torch.sigmoid(conf[0][0]).cpu().numpy()
            detection_score: Optional[float] = None
            if det is not None:
                detection_score = torch.sigmoid(det).item()
        return TruForOutputs(
            tamper_map=tamper_map,
            confidence_map=confidence_map,
            detection_score=detection_score,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_model(self) -> torch.nn.Module:
        """Instantiate the network and load weights from ``self.weights_path``.

        Weights are loaded with ``strict=False``; missing and unexpected keys
        are logged as warnings instead of raising.
        """
        cfg = self._default_config()
        model = TruForNetwork(cfg=cfg)
        # SECURITY: ``weights_only=False`` unpickles arbitrary Python objects,
        # which can execute code. Only point ``weights_path`` at trusted files.
        # NOTE(review): presumably needed because the checkpoint stores more
        # than tensors -- confirm whether ``weights_only=True`` works for it.
        checkpoint = torch.load(self.weights_path, map_location="cpu", weights_only=False)
        state_dict = self._extract_state_dict(checkpoint)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            LOGGER.warning("TruFor missing keys: %s", sorted(missing))
        if unexpected:
            LOGGER.warning("TruFor unexpected keys: %s", sorted(unexpected))
        return model

    @staticmethod
    def _extract_state_dict(checkpoint: object) -> dict:
        """Return the parameter mapping contained in a loaded checkpoint.

        Accepts a checkpoint dict with a ``"state_dict"`` entry, a bare state
        dict, or a pickled ``torch.nn.Module``.

        Raises:
            TypeError: If ``checkpoint`` is none of the above.
        """
        if isinstance(checkpoint, torch.nn.Module):
            return checkpoint.state_dict()
        if isinstance(checkpoint, dict):
            return checkpoint.get("state_dict", checkpoint)
        raise TypeError(
            f"Unsupported TruFor checkpoint type: {type(checkpoint).__name__}"
        )

    @staticmethod
    def _prepare_tensor(image: Image.Image) -> torch.Tensor:
        """Convert ``image`` to a ``(1, 3, H, W)`` float32 tensor scaled by 1/256."""
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
        # HWC -> CHW, then add a leading batch dimension.
        tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
        # Dividing by 256 (not 255) matches the reference implementation.
        return tensor / 256.0

    class AttrNamespace(dict):
        """``dict`` whose keys can also be read, written and deleted as attributes.

        ``_default_config`` builds the nested network configuration from these.
        Membership tests (``in``) use the plain ``dict`` behaviour.
        """

        def __getattr__(self, item):
            # Only invoked when normal attribute lookup fails; a missing key
            # becomes AttributeError so ``getattr``/``hasattr`` behave normally.
            try:
                return self[item]
            except KeyError as exc:
                raise AttributeError(item) from exc

        def __setattr__(self, key, value):
            self[key] = value

        def __delattr__(self, key):
            # Mirror ``__setattr__``: attribute deletion removes the dict key.
            try:
                del self[key]
            except KeyError as exc:
                raise AttributeError(key) from exc

    @classmethod
    def _default_config(cls) -> AttrNamespace:
        """Return the fixed configuration passed to ``TruForNetwork(cfg=...)``.

        The result exposes ``MODEL`` (name, input modalities, and an ``EXTRA``
        section with backbone/decoder settings) and ``DATASET.NUM_CLASSES == 2``.
        """
        extra = cls.AttrNamespace(
            BACKBONE="mit_b2",
            DECODER="MLPDecoder",
            DECODER_EMBED_DIM=512,
            PREPRC="imagenet",
            BN_EPS=0.001,
            BN_MOMENTUM=0.1,
            DETECTION="confpool",
            CONF=True,
            NP_WEIGHTS="",
        )
        model = cls.AttrNamespace(
            NAME="detconfcmx",
            MODS=("RGB", "NP++"),
            PRETRAINED="",
            EXTRA=extra,
        )
        dataset = cls.AttrNamespace(NUM_CLASSES=2)
        return cls.AttrNamespace(MODEL=model, DATASET=dataset)