File size: 2,858 Bytes
a5af557
 
60b5ed2
a5af557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60b5ed2
 
 
 
 
 
a5af557
60b5ed2
a5af557
60b5ed2
 
a5af557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60b5ed2
 
 
a5af557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
from pathlib import Path

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch.functional import img_to_tensor
from huggingface_hub import hf_hub_download
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.utils import draw_segmentation_masks, make_grid, save_image

import utils.misc as misc
from models import get_ensemble_model
from opt import get_opt


def greet(input_image):
    """Detect image manipulation and return a mask overlay plus a verdict string.

    Args:
        input_image: Value from the ``gr.Image()`` input component. It is
            converted with ``np.array`` and indexed as (H, W, C); presumably an
            RGB uint8 array as produced by Gradio's default image mode — TODO
            confirm if the input component is reconfigured.

    Returns:
        A tuple ``(overlay, output_string)``:
        ``overlay`` is the full-resolution input as an (H, W, C) uint8 array
        with pixels whose score exceeds ``opt.mask_threshold`` highlighted by
        ``draw_segmentation_masks``. ``output_string`` reports the maximum score
        of the ensemble map. When the image had to be downscaled, it also
        carries a note saying so.

    Raises:
        gr.Error: if no image was supplied (e.g. Submit pressed with an empty
            input), so Gradio shows a readable message instead of an opaque
            unpacking error.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    opt, model = _get_model()
    max_side = 1024  # longest side fed to the model without downscaling

    with torch.no_grad():
        image = np.array(input_image)
        h, w = image.shape[:2]
        resized = max(h, w) > max_side

        # Full-resolution (C, H, W) copy, used as the canvas for the overlay.
        dsm_image = torch.from_numpy(image).permute(2, 0, 1)

        if resized:
            image = A.LongestMaxSize(max_side)(image=image)["image"]
        model_input = img_to_tensor(
            image,
            normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
        )
        model_input = model_input.to(opt.device).unsqueeze(0)
        # Ask for the map at the ORIGINAL (h, w). draw_segmentation_masks needs
        # the mask to match dsm_image's spatial size. (Assumes seg_size sets the
        # output resolution — confirm in models.)
        outputs = model(model_input, seg_size=(h, w))
        out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
        pred = outputs["ensemble"]["out_map"].max().item()

        if pred > opt.mask_threshold:
            output_string = f"Found manipulation (manipulation probability {pred:.2f})."
        else:
            output_string = (
                f"No manipulation found (manipulation probability {pred:.2f})."
            )
        if resized:
            output_string += f"\nNote: Image was too large ({h}, {w}) and was resized to fit the model, which may decrease accuracy. We recommend image size smaller than 1024x1024."

        overlay = draw_segmentation_masks(
            dsm_image, masks=out_map[0, ...] > opt.mask_threshold
        )
    # (C, H, W) -> (H, W, C) uint8 array for the Gradio "image" output.
    return overlay.permute(1, 2, 0).cpu().numpy().astype(np.uint8), output_string


@functools.lru_cache(maxsize=None)
def _get_model(config_path="configs/final.yaml", ckpt_path="tmp/checkpoint.pt"):
    """Build the WSCL ensemble model and load its checkpoint (memoised).

    The result is cached per argument set. Repeated calls, such as one per
    Gradio request, reuse the already-loaded model instead of re-reading the
    config and checkpoint from disk each time.

    Args:
        config_path: YAML config path passed to ``get_opt``.
        ckpt_path: Local checkpoint path. If it does not exist, ``checkpoint.pt``
            is downloaded from the ``yhzhai/WSCL`` Hugging Face Hub repo into
            this path's parent directory.

    Returns:
        ``(opt, model)``: the options object from ``get_opt``, and the ensemble
        model on ``opt.device`` in eval mode with the checkpoint loaded.
        ``opt.resume`` is set to the checkpoint file actually loaded.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
        # Use the path hf_hub_download returns. The file is stored under its
        # repo name ("checkpoint.pt"), which differs from ckpt_path whenever the
        # caller chose another file name.
        ckpt_path = Path(
            hf_hub_download(
                repo_id="yhzhai/WSCL",
                filename="checkpoint.pt",
                local_dir=ckpt_path.parent.as_posix(),
            )
        )

    opt = get_opt(config_path)
    opt.resume = ckpt_path.as_posix()

    model = get_ensemble_model(opt).to(opt.device)
    misc.resume_from(model, opt.resume)
    # Inference only: switch any dropout / batch-norm layers to eval behaviour.
    model.eval()
    return opt, model


# Sample images offered as clickable examples; each example row holds a single
# value for the single image input.
_EXAMPLE_IMAGES = ["demo/au.jpg", "demo/tp.jpg"]

# Gradio UI: one image in, (overlay image, verdict text) out. Example outputs
# are precomputed (cache_examples=True), which runs greet on them at startup.
iface = gr.Interface(
    greet,
    inputs=gr.Image(),
    outputs=["image", "text"],
    title="WSCL: Image Manipulation Detection",
    examples=[[example_path] for example_path in _EXAMPLE_IMAGES],
    cache_examples=True,
)
iface.launch()