# DFDetection / app.py
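"""Gradio Space: deepfake detection for images and videos, with a Grad-CAM
visualisation of the region driving the classifier's decision."""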
import gradio as gr
import warnings
import cv2
import dlib
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import numpy as np
import torch
from retinaface.pre_trained_models import get_model
from Scripts.model import create_cam, create_model
from Scripts.preprocess import crop_face, extract_face, extract_frames
from Scripts.ca_generator import get_augs
import spaces
warnings.filterwarnings('ignore')
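
# Device selection, classifier weights, RetinaFace face detector, and Grad-CAM setup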
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sbcl = create_model("Weights/weights.tar")
sbcl.to(device)
face_detector = get_model("resnet50_2020-07-20", max_size=1024, device=device)
face_detector.eval()
cam_sbcl = create_cam(sbcl)
targets = [ClassifierOutputTarget(1)]  # CAM target: the "Fake" class (index 1)

# Examples
examples = ["Examples/Fake/Fake1.PNG", "Examples/Real/Real1.PNG", "Examples/Real/Real2.PNG",
            "Examples/Fake/Fake3.PNG", "Examples/Fake/Fake2.PNG"]
examples_videos = ['Examples/Fake1.mp4', 'Examples/Real1.mp4']
# dlib Models
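# Note: the dlib detector/landmark predictor are loaded here but are not referenced
# directly in this file; presumably they are used by the helpers under Scripts/.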
dlib_face_detector = dlib.get_frontal_face_detector()
dlib_face_predictor = dlib.shape_predictor(
    'Weights/shape_predictor_81_face_landmarks.dat')

@spaces.GPU
def predict_image(inp):
    # Detect and crop faces; bail out early if none are found
    face_list = extract_face(inp, face_detector)
    if len(face_list) == 0:
        return {'No face detected!': 1}, None

    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        pred = sbcl(img).softmax(1)[:, 1].cpu().data.numpy().tolist()[0]
        confidences = {'Real': 1 - pred, 'Fake': pred}

    # Grad-CAM needs gradients, so it runs outside the no_grad block
    grayscale_cam = cam_sbcl(input_tensor=img, targets=targets, aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    cam_image = show_cam_on_image(face_list[0].transpose(1, 2, 0) / 255, grayscale_cam, use_rgb=True)
    return confidences, cam_image

@spaces.GPU
def predict_video(inp):
    # Sample up to 10 frames, detect faces, and score every detected face
    face_list, idx_list = extract_frames(inp, 10, face_detector)
    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        pred = sbcl(img).softmax(1)[:, 1]

    # Group face scores by frame index and keep the highest score per frame
    pred_list = []
    idx_img = -1
    for i in range(len(pred)):
        if idx_list[i] != idx_img:
            pred_list.append([])
            idx_img = idx_list[i]
        pred_list[-1].append(pred[i].item())
    pred_res = np.zeros(len(pred_list))
    for i in range(len(pred_res)):
        pred_res[i] = max(pred_list[i])
    pred = pred_res.mean()

    # Run Grad-CAM on the highest-scoring frame (note: this indexing assumes one
    # detected face per sampled frame, so group indices line up with face indices)
    most_fake = np.argmax(pred_res)
    grayscale_cam = cam_sbcl(input_tensor=img[most_fake].unsqueeze(0), targets=targets, aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    cam_image = show_cam_on_image(face_list[most_fake].transpose(1, 2, 0) / 255, grayscale_cam, use_rgb=True)
    return {'Real': 1 - pred, 'Fake': pred}, cam_image

with gr.Blocks(title="Deepfake Detection CL", theme='upsatwal/mlsc_tiet', css="""
@import url('https://fonts.googleapis.com/css?family=Source+Code+Pro:200');
#custom_header {
    min-height: 3rem;
    background-image: url('https://static.pexels.com/photos/414171/pexels-photo-414171.jpeg');
    background-size: cover;
    background-position: top;
    color: white;
    text-align: center;
    padding: 0.5rem;
    font-family: 'Source Code Pro', monospace;
    text-transform: uppercase;
}
#custom_header:hover {
    -webkit-animation: slidein 10s;
    animation: slidein 10s;
    -webkit-animation-fill-mode: forwards;
    animation-fill-mode: forwards;
    -webkit-animation-iteration-count: infinite;
    animation-iteration-count: infinite;
    -webkit-animation-direction: alternate;
    animation-direction: alternate;
}
@-webkit-keyframes slidein {
    from {
        background-position: top;
        background-size: 3000px;
    }
    to {
        background-position: -100px 0px;
        background-size: 2750px;
    }
}
@keyframes slidein {
    from {
        background-position: top;
        background-size: 3000px;
    }
    to {
        background-position: -100px 0px;
        background-size: 2750px;
    }
}
#custom_title {
    min-height: 3rem;
    text-align: center;
}
.full-width {
    width: 100%;
}
.full-width:hover {
    background: rgba(75, 75, 250, 0.3);
    color: white;
}
""") as demo:
with gr.Tab("Image"):
with gr.Row():
with gr.Column():
with gr.Group():
gr.Markdown("## Deepfake Detection", elem_id="custom_header")
input_image = gr.Image(label="Input Image", height=240)
btn = gr.Button(value="Submit", variant="primary", elem_classes="full-width")
with gr.Column():
with gr.Group():
gr.Markdown("## Result", elem_id="custom_header")
output_image = gr.Image(label="GradCAM Image", height=240)
label_probs = gr.Label()
gr.Examples(
examples=examples,
inputs=input_image,
outputs=[label_probs, output_image],
fn=predict_image,
cache_examples=True,
)
btn.click(predict_image, inputs=input_image, outputs=[label_probs, output_image], api_name="/predict_image")
with gr.Tab("Video"):
with gr.Row():
with gr.Column():
with gr.Group():
gr.Markdown("## Deepfake Detection", elem_id="custom_header")
input_video = gr.Video(label="Input Video", height=240)
btn_video = gr.Button(value="Submit", variant="primary", elem_classes="full-width")
with gr.Column():
with gr.Group():
gr.Markdown("## Result", elem_id="custom_header")
output_image_video = gr.Image(label="GradCAM", height=240)
label_probs_video = gr.Label()
gr.Examples(
examples=examples_videos,
inputs=input_video,
outputs=[label_probs_video, output_image_video],
fn=predict_video,
cache_examples=True,
)
btn_video.click(predict_video, inputs=input_video, outputs=[label_probs_video, output_image_video], api_name="/predict_video")
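
# A minimal sketch (untested assumption) of how the two endpoints wired above could be
# called remotely with gradio_client; the Space id below is hypothetical, and newer
# gradio_client versions may require wrapping local files with handle_file(...):
#
#   from gradio_client import Client
#   client = Client("<user>/DFDetection")  # hypothetical Space id
#   label, cam = client.predict("face.png", api_name="/predict_image")
#   label, cam = client.predict("clip.mp4", api_name="/predict_video")
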
if __name__ == "__main__":
    demo.launch()