|
from functools import lru_cache
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from albumentations.pytorch.functional import img_to_tensor
from huggingface_hub import hf_hub_download
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.utils import draw_segmentation_masks, make_grid, save_image

import utils.misc as misc
from models import get_ensemble_model
from opt import get_opt
|
|
|
|
|
def greet(input_image):
    """Run the WSCL model on one image and visualize the predicted mask.

    Args:
        input_image: Value produced by the ``gr.Image()`` input component.
            Assumed to be an RGB uint8 array of shape (H, W, 3), which is
            Gradio's default numpy output (TODO confirm if the component
            configuration changes). It is ``None`` when the user submits
            without uploading an image.

    Returns:
        A tuple ``(overlay, output_string)``:
        ``overlay`` is the input image with the thresholded manipulation
        mask drawn on it, as an (H, W, C) uint8 numpy array.
        ``output_string`` is a human-readable verdict that quotes the
        maximum value of the predicted map.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Without this guard, np.array(None) yields a 0-d object array and
        # torch.from_numpy fails with an opaque TypeError. gr.Error is shown
        # to the user as a readable message instead.
        raise gr.Error("Please upload an image first.")

    opt, model = _get_model()

    with torch.no_grad():
        # np.array (not asarray) copies, so the array handed to
        # torch.from_numpy below is writable even if Gradio's is not.
        image = np.array(input_image)

        # Channels-first uint8 copy of the raw image, used for drawing the
        # mask overlay (draw_segmentation_masks expects (C, H, W)).
        dsm_image = torch.from_numpy(image).permute(2, 0, 1)

        # (H, W) of the input. It is passed to the model as seg_size,
        # presumably so the predicted map comes back at input resolution,
        # which the overlay below requires.
        image_size = image.shape[:2]

        image = img_to_tensor(
            image,
            normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
        )
        image = image.to(opt.device).unsqueeze(0)  # add batch dimension

        outputs = model(image, seg_size=image_size)

        # Presumably (batch, 1, H, W); take the single batch element.
        # TODO confirm shape.
        out_map = outputs["ensemble"]["out_map"][0, ...].cpu()
        # Image-level score: the maximum of the predicted map.
        pred = outputs["ensemble"]["out_map"].max().item()

    if pred > opt.mask_threshold:
        output_string = f"Found manipulation (manipulation probability {pred:.2f})."
    else:
        output_string = (
            f"No manipulation found (manipulation probability {pred:.2f})."
        )

    # The same threshold is used for the pixel mask as for the verdict above.
    overlay = draw_segmentation_masks(
        dsm_image, masks=out_map[0, ...] > opt.mask_threshold
    )
    # Back to (H, W, C) for Gradio's image output.
    overlay = overlay.permute(1, 2, 0).numpy().astype(np.uint8)

    return overlay, output_string
|
|
|
|
|
# Cached so that each Gradio request (greet calls this every time) reuses the
# already-built model instead of rebuilding it and reloading the checkpoint.
# maxsize=1 keeps at most one model resident. Concurrent first calls may
# still build it more than once, which is harmless.
@lru_cache(maxsize=1)
def _get_model(config_path="configs/final.yaml", ckpt_path="tmp/checkpoint.pt"):
    """Build the WSCL ensemble model and load its checkpoint.

    If ``ckpt_path`` does not exist, ``checkpoint.pt`` is downloaded from the
    Hugging Face Hub repo ``yhzhai/WSCL`` into ``ckpt_path``'s parent
    directory, and the downloaded file is used.

    Args:
        config_path: Config file passed to ``get_opt``.
        ckpt_path: Local checkpoint path (str or ``Path``).

    Returns:
        ``(opt, model)``: the parsed options, with ``opt.resume`` set to the
        checkpoint actually loaded, and the model on ``opt.device`` with
        weights restored and switched to eval mode.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
        downloaded = hf_hub_download(
            repo_id="yhzhai/WSCL",
            filename="checkpoint.pt",
            local_dir=ckpt_path.parent.as_posix(),
        )
        # The hub file is always named "checkpoint.pt". Use the path that
        # hf_hub_download returns, so a ckpt_path with a different file name
        # still resolves to the file that was actually written.
        ckpt_path = Path(downloaded)

    opt = get_opt(config_path)
    opt.resume = ckpt_path.as_posix()

    model = get_ensemble_model(opt).to(opt.device)
    misc.resume_from(model, opt.resume)
    # Inference-only usage: put dropout / normalization layers in eval mode.
    model.eval()

    return opt, model
|
|
|
|
|
# Gradio UI: one image in, an overlay image plus a text verdict out.
# The example files presumably show an authentic ("au") and a tampered
# ("tp") image. cache_examples=True precomputes their outputs with greet.
iface = gr.Interface(
    fn=greet,
    title="WSCL: Image Manipulation Detection",
    inputs=gr.Image(),
    outputs=["image", "text"],
    examples=[["demo/au.jpg"], ["demo/tp.jpg"]],
    cache_examples=True,
)

# Start the (blocking) server only when this file is run as a script, so that
# importing the module does not launch it. `python app.py` behaves as before.
if __name__ == "__main__":
    iface.launch()
|
|