visinject / utils.py
jeffliulab's picture
Initial Space deployment: Stage 2 fusion demo (CPU, free tier)
e1887f1 verified
"""
Utilities used by app.py.
This is a Space-local subset of the project's `utils.py` — only the helpers
needed for Stage 2 fusion (image I/O, decoder loading, PSNR).
"""
import math

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from decoder import Decoder
def load_image(image_path: str, size: int = 224) -> torch.Tensor:
    """Load an image file as a (1, 3, size, size) tensor in [0, 1].

    Args:
        image_path: Path of the image file to read.
        size: Edge length, in pixels, of the square the image is resized to.

    Returns:
        The RGB image after resizing and ``ToTensor`` conversion, with a
        leading batch dimension of 1 added.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager closes it deterministically. convert() returns an independent,
    # fully loaded copy, so it is safe to use after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])
    return transform(rgb).unsqueeze(0)
def load_decoder(path: str, embed_dim: int = 512, device: torch.device = None) -> Decoder:
    """Build a Decoder and fill it with weights from an AnyAttack checkpoint.

    Args:
        path: Checkpoint file to read with ``torch.load``.
        embed_dim: Embedding dimension passed to the ``Decoder`` constructor.
        device: Device the decoder is moved to before the weights are loaded.

    Returns:
        The decoder, in eval mode, with the remapped weights loaded.
    """
    model = Decoder(embed_dim=embed_dim)
    model = model.to(device)
    model = model.eval()

    # NOTE(review): weights_only=False performs full unpickling, which can
    # execute arbitrary code; only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    # Use the nested "decoder_state_dict" entry when present; otherwise the
    # checkpoint object itself is treated as the state dict.
    weights = checkpoint.get("decoder_state_dict", checkpoint)

    def _rename(key: str) -> str:
        # Drop a leading "module." (presumably left by a DataParallel-style
        # wrapper -- confirm) and translate checkpoint layer names to the
        # names this Decoder is expected to use.
        return (
            key.removeprefix("module.")
            .replace("upsample_blocks.", "blocks.")
            .replace("final_conv.", "head.")
        )

    model.load_state_dict({_rename(name): tensor for name, tensor in weights.items()})
    return model
def compute_psnr(img1: torch.Tensor, img2: torch.Tensor, max_val: float = 1.0) -> float:
    """Compute the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: First image tensor.
        img2: Second image tensor, broadcastable against ``img1``.
        max_val: Largest possible pixel value. The default of 1.0 matches
            images scaled to [0, 1]; pass 255 for 8-bit images.

    Returns:
        ``20 * log10(max_val) - 10 * log10(mse)``, or ``inf`` when the two
        images are identical.

    Raises:
        ValueError: If ``max_val`` is not positive.
    """
    if max_val <= 0:
        raise ValueError(f"max_val must be positive, got {max_val}")

    # Integer inputs (e.g. uint8) would wrap around on subtraction and make
    # torch.mean fail, so promote them to floating point first.
    if not img1.is_floating_point():
        img1 = img1.float()
    if not img2.is_floating_point():
        img2 = img2.float()

    mse = torch.mean((img1 - img2) ** 2).item()
    if mse == 0:
        return float("inf")

    # mse is already a Python float, so take the log with math instead of
    # round-tripping through a float32 tensor. With max_val == 1.0 the first
    # term is exactly 0.0, leaving the original -10 * log10(mse).
    return 20.0 * math.log10(max_val) - 10.0 * math.log10(mse)