File size: 3,365 Bytes
c61b6f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import sys
from pathlib import Path
import torch
from PIL import Image
import torchvision.transforms as transforms

from networks.drn_seg import DRNSub, DRNSeg
from utils.tools import *
from utils.visualize import *

import gradio as gr
from huggingface_hub import hf_hub_download


def load_classifier(type, model_path, device=torch.device("cpu")):
    """Build one of the FAL detector networks and load its checkpoint.

    Args:
        type: ``'global'`` builds ``DRNSub(1)``; ``'local'`` builds
            ``DRNSeg(2)``.  (The name shadows the builtin but is kept so
            existing keyword callers keep working.)
        model_path: Path of a checkpoint file whose ``'model'`` entry holds
            the network's state dict.
        device: Device the checkpoint tensors are mapped to and the model
            is moved to.

    Returns:
        The model in eval mode on ``device``, with an extra ``device``
        attribute recording that device.

    Raises:
        ValueError: If ``type`` is neither ``'global'`` nor ``'local'``.
    """
    if type == 'global':
        model = DRNSub(1)
    elif type == 'local':
        model = DRNSeg(2)
    else:
        # Fail fast instead of loading the checkpoint and then hitting an
        # UnboundLocalError on `model`.
        raise ValueError(
            f"Unknown classifier type {type!r}; expected 'global' or 'local'")
    # NOTE(security): torch.load unpickles the file, so only load trusted
    # checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model


# Fetch both checkpoints from the Hugging Face Hub repo (local model first,
# then the global one). token=True tells huggingface_hub to use the locally
# saved access token.
local_model_file, global_model_file = [
    hf_hub_download(repo_id="radames/FALdetector", filename=checkpoint_name,
                    token=True)
    for checkpoint_name in ("local.pth", "global.pth")
]


# Use the GPU when torch can see one; otherwise run on the CPU.
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))
global_model = load_classifier("global", global_model_file, device)
local_model = load_classifier("local", local_model_file, device)
# Weights file handed to face_detection. Relative path, so presumably it is
# resolved against the working directory — run the app from the repo root.
faces_model_file = 'utils/dlib_face_detector/mmod_human_face_detector.dat'

# PIL image -> tensor, then per-channel normalization (these are the usual
# ImageNet mean/std values).
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def predict(img_path):
    """Run both FAL detector networks on the first face found in an image.

    Args:
        img_path: Filesystem path of the input image (the Gradio input uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(scores, modified, heat_map, reverse)``:

        - ``scores``: ``{"Probability FAL": prob}``, where ``prob`` is the
          sigmoid of the global model's output.
        - ``modified``: the detected face resized to the local model's
          output resolution.
        - ``heat_map``: a heat map of the per-pixel flow magnitude over
          ``modified``.
        - ``reverse``: ``modified`` warped by the predicted flow (the
          "suggested undo").

    Raises:
        gr.Error: If no image was supplied or dlib detects no face.
    """
    if img_path is None:
        # An empty Gradio image input arrives as None; report it cleanly
        # instead of crashing inside the face detector.
        raise gr.Error("Please upload an image first")

    faces = face_detection(img_path, verbose=False,
                           model_file=faces_model_file)
    if len(faces) == 0:
        raise gr.Error("No face detected by dlib")

    # Only the first detected face is analysed; its bounding box is unused.
    face, _box = faces[0]
    face = resize_shorter_side(face, 400)[0]
    batch = tf(face).to(device).unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        prob = global_model(batch)[0].sigmoid().cpu().item()
        flow = local_model(batch)[0].cpu().numpy()
    # CHW -> HWC so the flow can be indexed per pixel.
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps: bring the face to the flow's resolution, then warp.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse = Image.fromarray(warp(modified_np, flow))

    # Heat map of the flow magnitude (max_flow_mag=7 presumably sets the top
    # of the colour scale in get_heatmap_cv — confirm).
    flow_magn = np.sqrt(flow[:, :, 0] ** 2 + flow[:, :, 1] ** 2)
    heat_map = Image.fromarray(
        get_heatmap_cv(modified_np, flow_magn, max_flow_mag=7))

    print(prob)
    return {"Probability FAL": prob}, modified, heat_map, reverse


with gr.Blocks() as blocks:
    gr.Markdown("""
## Unofficial Demo
### Detecting Photoshopped Faces by Scripting Photoshop

* https://arxiv.org/abs/1906.05856
* https://peterwang512.github.io/FALdetector/
    """)
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Input Image", type="filepath")
            # A Button's caption is its first positional argument (`value`);
            # `label=` is not a Button parameter in current Gradio releases.
            run_btn = gr.Button("Run")
        with gr.Column():
            label = gr.Label(
                label="Probability being modified by Photoshop FAL")
    with gr.Row():
        cropped = gr.Image(label="Cropped Input Image")
        heatmap = gr.Image(label="Heatmap")
        warped = gr.Image(label="Suggested Undo")
    # Same order as the tuple returned by predict().
    outputs = [label, cropped, heatmap, warped]
    run_btn.click(fn=predict, inputs=[in_image], outputs=outputs)
    # Plain string paths, sorted so the example order does not depend on
    # filesystem iteration order.
    example_paths = [str(p) for p in sorted(Path("./examples").glob("*.png"))]
    gr.Examples(fn=predict, examples=example_paths, inputs=[in_image],
                outputs=outputs)
blocks.launch()