Yuanhao Zhai commited on
Commit
a5af557
1 Parent(s): 2f7bb6a

add gradio app

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +74 -0
  3. demo.py +4 -3
README.md CHANGED
@@ -74,7 +74,7 @@ python demo.py --load configs/final.yaml --resume checkpoint-path
74
  ```
75
 
76
  By default, it evaluates all `.jpg` files in the `demo` folder, and saves the
77
- detection result in `tmp`.
78
 
79
 
80
 
 
74
  ```
75
 
76
  By default, it evaluates all `.jpg` files in the `demo` folder, and saves the
77
+ detection result in `tmp`, with manipulation probablities appended to the file names.
78
 
79
 
80
 
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from albumentations.pytorch.functional import img_to_tensor
7
+ from huggingface_hub import hf_hub_download
8
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9
+ from torchvision.utils import draw_segmentation_masks, make_grid, save_image
10
+
11
+ import utils.misc as misc
12
+ from models import get_ensemble_model
13
+ from opt import get_opt
14
+
15
+
16
+ def greet(input_image):
17
+ opt, model = _get_model()
18
+
19
+ with torch.no_grad():
20
+ image = input_image
21
+ image = np.array(image)
22
+ dsm_image = torch.from_numpy(image).permute(2, 0, 1)
23
+ image_size = image.shape[:2]
24
+ image = img_to_tensor(
25
+ image,
26
+ normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
27
+ )
28
+ image = image.to(opt.device).unsqueeze(0)
29
+ outputs = model(image, seg_size=image_size)
30
+ out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
31
+ pred = outputs["ensemble"]["out_map"].max().item()
32
+ if pred > opt.mask_threshold:
33
+ output_string = f"Found manipulation (manipulation probability {pred:.2f})."
34
+ else:
35
+ output_string = (
36
+ f"No manipulation found (manipulation probability {pred:.2f})."
37
+ )
38
+
39
+ overlay = draw_segmentation_masks(
40
+ dsm_image, masks=out_map[0, ...] > opt.mask_threshold
41
+ )
42
+ overlay = overlay.permute(1, 2, 0)
43
+ overlay = overlay.detach().cpu().numpy()
44
+ overlay = overlay.astype(np.uint8)
45
+ return overlay, output_string
46
+
47
+
48
+ def _get_model(config_path="configs/final.yaml", ckpt_path="tmp/checkpoint.pt"):
49
+ ckpt_path = Path(ckpt_path)
50
+ if not ckpt_path.exists():
51
+ ckpt_path.parent.mkdir(exist_ok=True, parents=True)
52
+ hf_hub_download(
53
+ repo_id="yhzhai/WSCL",
54
+ filename="checkpoint.pt",
55
+ local_dir=ckpt_path.parent.as_posix(),
56
+ )
57
+
58
+ opt = get_opt(config_path)
59
+ opt.resume = ckpt_path.as_posix()
60
+
61
+ model = get_ensemble_model(opt).to(opt.device)
62
+ misc.resume_from(model, opt.resume)
63
+ return opt, model
64
+
65
+
66
+ iface = gr.Interface(
67
+ fn=greet,
68
+ title="WSCL: Image Manipulation Detection",
69
+ inputs=gr.Image(),
70
+ outputs=["image", "text"],
71
+ examples=[["demo/au.jpg"], ["demo/tp.jpg"]],
72
+ cache_examples=True,
73
+ )
74
+ iface.launch()
demo.py CHANGED
@@ -1,9 +1,10 @@
 
 
1
  import albumentations as A
2
  import cv2
3
  import torch
4
  import tqdm
5
  from albumentations.pytorch.functional import img_to_tensor
6
- from pathlib import Path
7
  from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8
  from torchvision.utils import draw_segmentation_masks, make_grid, save_image
9
 
@@ -31,8 +32,8 @@ def demo(folder_path, output_path=Path("tmp")):
31
  image = image.to(opt.device).unsqueeze(0)
32
  outputs = model(image, seg_size=image_size)
33
  out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
34
- pred = outputs["ensemble"]["out_map"].max().item()
35
- if pred > opt.mask_threshold:
36
  print(f"Found manipulation in {image_path.name}")
37
  else:
38
  print(f"No manipulation found in {image_path.name}")
 
1
+ from pathlib import Path
2
+
3
  import albumentations as A
4
  import cv2
5
  import torch
6
  import tqdm
7
  from albumentations.pytorch.functional import img_to_tensor
 
8
  from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9
  from torchvision.utils import draw_segmentation_masks, make_grid, save_image
10
 
 
32
  image = image.to(opt.device).unsqueeze(0)
33
  outputs = model(image, seg_size=image_size)
34
  out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
35
+ pred = outputs["ensemble"]["out_map"].max().item()
36
+ if pred > opt.mask_threshold:
37
  print(f"Found manipulation in {image_path.name}")
38
  else:
39
  print(f"No manipulation found in {image_path.name}")