radames's picture
Update app.py
7fb1c45
raw
history blame contribute delete
No virus
3.46 kB
import os
import sys
from pathlib import Path
import torch
from PIL import Image
import torchvision.transforms as transforms
from networks.drn_seg import DRNSub, DRNSeg
from utils.tools import *
from utils.visualize import *
import gradio as gr
from huggingface_hub import hf_hub_download
def load_classifier(type, model_path, device=torch.device("cpu")):
    """Build one of the two FAL detector networks and load its checkpoint.

    Args:
        type: ``"global"`` builds ``DRNSub(1)`` (its output is used as the
            manipulation probability in ``predict``); ``"local"`` builds
            ``DRNSeg(2)`` (its output is used as a 2-channel flow field).
            The name shadows the builtin but is kept for caller compatibility.
        model_path: Path to a checkpoint whose ``"model"`` entry holds the
            network's state dict.
        device: Device the weights are mapped onto and the model is moved to.

    Returns:
        The model in eval mode, on ``device``, with a ``device`` attribute
        recording where it was placed.

    Raises:
        ValueError: If ``type`` is neither ``"global"`` nor ``"local"``.
    """
    if type == 'global':
        model = DRNSub(1)
    elif type == 'local':
        model = DRNSeg(2)
    else:
        # Fail fast with a clear message instead of loading the checkpoint
        # and then hitting an UnboundLocalError on ``model``.
        raise ValueError(
            f"unknown classifier type {type!r}; expected 'global' or 'local'")
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (here they come from a fixed Hub repo).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.device = device
    model.eval()
    return model
# Fetch both FAL detector checkpoints from the Hub repo, local first and
# then global; the generator is unpacked in that order.
local_model_file, global_model_file = (
    hf_hub_download(repo_id="radames/FALdetector", filename=checkpoint, token=True)
    for checkpoint in ("local.pth", "global.pth")
)

# Prefer the GPU whenever torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build and load both networks onto the chosen device.
global_model = load_classifier("global", global_model_file, device)
local_model = load_classifier("local", local_model_file, device)

# dlib face-detector weights shipped alongside the app.
faces_model_file = 'utils/dlib_face_detector/mmod_human_face_detector.dat'

# Input preprocessing: image -> tensor, then per-channel normalisation with
# the usual ImageNet mean/std values.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def predict(img_path):
    """Run the FAL detector on the first face found in an image.

    Args:
        img_path: Path to the uploaded image (the input component is a
            ``gr.Image`` with ``type="filepath"``); ``None`` when nothing
            was uploaded.

    Returns:
        A 4-tuple ``(label, modified, heat_map, reverse)``:

        * a dict mapping ``"Probability FAL"`` to the global model's
          sigmoid output;
        * the cropped face resized to the local model's output size;
        * a heat map of the predicted flow magnitude over that face;
        * the face with the predicted warp undone.

    Raises:
        gr.Error: If no image was supplied or dlib detects no face.
    """
    if img_path is None:
        raise gr.Error("Please upload an image first")
    faces = face_detection(img_path, verbose=False,
                           model_file=faces_model_file)
    if len(faces) == 0:
        raise gr.Error("No face detected by dlib")
    # Only the first detected face is analysed; its bounding box is unused.
    face, _box = faces[0]
    face = resize_shorter_side(face, 400)[0]
    # Build the batch-of-one tensor once and feed it to both models.
    batch = tf(face).to(device).unsqueeze(0)
    with torch.no_grad():
        prob = global_model(batch)[0].sigmoid().cpu().item()
        flow = local_model(batch)[0].cpu().numpy()
    # Move the channel axis last so the array is indexed (h, w, channel).
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape
    # Undoing the warps
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)
    # Heat map of the per-pixel flow magnitude.
    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    cv_out = get_heatmap_cv(modified_np, flow_magn, max_flow_mag=7)
    heat_map = Image.fromarray(cv_out)
    print(prob)
    return {"Probability FAL": prob}, modified, heat_map, reverse
# Gradio UI: one input image, a probability label, and three output images
# (cropped face, flow heat map, suggested undo).
with gr.Blocks() as blocks:
    gr.Markdown("""
## Unofficial Demo
### Detecting Photoshopped Faces by Scripting Photoshop
#### FAL Detector Live Demo
* https://arxiv.org/abs/1906.05856
* https://peterwang512.github.io/FALdetector/
""")
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Input Image", type="filepath")
            # The button text is its ``value``; ``label`` is not a Button
            # parameter (newer Gradio rejects it outright).
            run_btn = gr.Button("Run")
        with gr.Column():
            label = gr.Label(
                label="Probability being modified by Photoshop FAL")
    with gr.Row():
        cropped = gr.Image(label="Cropped Input Image")
        heatmap = gr.Image(label="Heatmap")
        warped = gr.Image(label="Suggested Undo")
    run_btn.click(fn=predict, inputs=[in_image], outputs=[
                  label, cropped, heatmap, warped])
    # Sort for a stable order (glob order is filesystem-dependent) and pass
    # plain string paths, one single-input example per row.
    _example_rows = [[str(p)] for p in sorted(Path("./examples").glob("*.png"))]
    # Skip the examples panel (and its caching pass) when there are none.
    if _example_rows:
        gr.Examples(fn=predict,
                    examples=_example_rows,
                    inputs=[in_image],
                    outputs=[label, cropped, heatmap, warped],
                    cache_examples=True)
blocks.launch()