import gradio as gr
import warnings
import itertools
import cv2
import dlib
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import numpy as np
import torch
from retinaface.pre_trained_models import get_model
from Scripts.model import create_cam, create_model
from Scripts.preprocess import crop_face, extract_face, extract_frames
from Scripts.ca_generator import get_augs
import spaces

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classifier, face detector and Grad-CAM wrapper are loaded once at import time.
sbcl = create_model("Weights/weights.tar")

face_detector = get_model("resnet50_2020-07-20", max_size=1024, device=device)
face_detector.eval()

cam_sbcl = create_cam(sbcl)
# Explain class index 1 -- the same softmax column reported as 'Fake' below.
targets = [ClassifierOutputTarget(1)]

# Examples
examples = [
    "Examples/Fake/Fake1.PNG",
    "Examples/Real/Real1.PNG",
    "Examples/Real/Real2.PNG",
    "Examples/Fake/Fake3.PNG",
    "Examples/Fake/Fake2.PNG",
]
examples_videos = ['Examples/Fake1.mp4', 'Examples/Real1.mp4']

# dlib Models
# NOTE(review): not referenced by the code in this file -- confirm whether
# another module relies on them before removing.
dlib_face_detector = dlib.get_frontal_face_detector()
dlib_face_predictor = dlib.shape_predictor(
    'Weights/shape_predictor_81_face_landmarks.dat')


@spaces.GPU
def predict_image(inp):
    """Classify the first detected face of an image as real or fake.

    Args:
        inp: Value of the ``gr.Image`` input; ``None`` when the user submits
            without uploading an image.

    Returns:
        ``(confidences, cam_image)`` for the ``gr.Label`` / ``gr.Image``
        outputs. ``confidences`` maps ``'Real'`` and ``'Fake'`` to the model's
        probabilities for the first face; ``cam_image`` is the Grad-CAM overlay
        of that face. With no input or no detected face, a single-entry
        message dict and ``None`` are returned instead.
    """
    if inp is None:
        return {'No image provided!': 1}, None

    face_list = extract_face(inp, face_detector)
    if len(face_list) == 0:
        return {'No face detected!': 1}, None

    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        pred = sbcl(img).softmax(1)[0, 1].item()
    confidences = {'Real': 1 - pred, 'Fake': pred}

    # Kept outside torch.no_grad(): gradient-based CAM methods need autograd.
    grayscale_cam = cam_sbcl(input_tensor=img, targets=targets, aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    cam_image = show_cam_on_image(
        face_list[0].transpose(1, 2, 0) / 255, grayscale_cam, use_rgb=True)
    return confidences, cam_image


def _per_frame_max(idx_list, face_scores):
    """Return, as a float array, the max score of each run of consecutive faces
    that share the same ``idx_list`` value."""
    pairs = zip(idx_list, face_scores)
    return np.array(
        [max(score for _, score in group)
         for _, group in itertools.groupby(pairs, key=lambda pair: pair[0])],
        dtype=float,
    )


@spaces.GPU
def predict_video(inp):
    """Classify a video as real or fake from faces sampled across its frames.

    Every face returned by ``extract_frames`` (called with ``10``, presumably
    the number of frames to sample) is scored. Faces are grouped by
    consecutive equal values of ``idx_list`` (presumably the source-frame
    index); each group's score is its maximum, and the video score is the
    mean of the group scores. The Grad-CAM overlay is computed for the single
    highest-scoring face.

    Args:
        inp: Value of the ``gr.Video`` input; ``None`` when the user submits
            without uploading a video.

    Returns:
        ``(confidences, cam_image)`` like ``predict_image``, or a
        single-entry message dict and ``None`` when there is no input or no
        face was detected.
    """
    if inp is None:
        return {'No video provided!': 1}, None

    face_list, idx_list = extract_frames(inp, 10, face_detector)
    if len(face_list) == 0:
        return {'No face detected!': 1}, None

    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        # One device-to-host transfer for all faces instead of .item() per face.
        face_preds = sbcl(img).softmax(1)[:, 1].cpu().numpy().tolist()

    frame_preds = _per_frame_max(idx_list, face_preds)
    pred = float(frame_preds.mean())

    # most_fake indexes faces (img / face_list), not frames. The face with the
    # highest score is also the one that sets the highest frame score.
    most_fake = int(np.argmax(face_preds))
    grayscale_cam = cam_sbcl(input_tensor=img[most_fake].unsqueeze(0),
                             targets=targets, aug_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    cam_image = show_cam_on_image(
        face_list[most_fake].transpose(1, 2, 0) / 255, grayscale_cam, use_rgb=True)

    return {'Real': 1 - pred, 'Fake': pred}, cam_image


CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css?family=Source+Code+Pro:200');

#custom_header {
    min-height: 3rem;
    background-image: url('https://static.pexels.com/photos/414171/pexels-photo-414171.jpeg');
    background-size: cover;
    background-position: top;
    color: white;
    text-align: center;
    padding: 0.5rem;
    font-family: 'Source Code Pro', monospace;
    text-transform: uppercase;
}

#custom_header:hover {
    -webkit-animation: slidein 10s;
    animation: slidein 10s;
    -webkit-animation-fill-mode: forwards;
    animation-fill-mode: forwards;
    -webkit-animation-iteration-count: infinite;
    animation-iteration-count: infinite;
    -webkit-animation-direction: alternate;
    animation-direction: alternate;
}

@-webkit-keyframes slidein {
    from {
        background-position: top;
        background-size: 3000px;
    }
    to {
        background-position: -100px 0px;
        background-size: 2750px;
    }
}

@keyframes slidein {
    from {
        background-position: top;
        background-size: 3000px;
    }
    to {
        background-position: -100px 0px;
        background-size: 2750px;
    }
}

#custom_title {
    min-height: 3rem;
    text-align: center;
}

.full-width {
    width: 100%;
}

.full-width:hover {
    background: rgba(75, 75, 250, 0.3);
    color: white;
}
"""

with gr.Blocks(title="Deepfake Detection CL", theme='upsatwal/mlsc_tiet', css=CUSTOM_CSS) as demo:
    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Deepfake Detection", elem_id="custom_header")
                    input_image = gr.Image(label="Input Image", height=240)
                    btn = gr.Button(value="Submit", variant="primary", elem_classes="full-width")
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Result", elem_id="custom_header")
                    output_image = gr.Image(label="GradCAM Image", height=240)
                    label_probs = gr.Label()
        gr.Examples(
            examples=examples,
            inputs=input_image,
            # predict_image returns (confidences, cam_image); match that order.
            outputs=[label_probs, output_image],
            fn=predict_image,
            cache_examples=False,
        )
        btn.click(predict_image, inputs=input_image, outputs=[label_probs, output_image],
                  api_name="/predict_image")

    with gr.Tab("Video"):
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Deepfake Detection", elem_id="custom_header")
                    input_video = gr.Video(label="Input Video", height=240)
                    btn_video = gr.Button(value="Submit", variant="primary", elem_classes="full-width")
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Result", elem_id="custom_header")
                    output_image_video = gr.Image(label="GradCAM", height=240)
                    label_probs_video = gr.Label()
        gr.Examples(
            examples=examples_videos,
            inputs=input_video,
            # predict_video returns (confidences, cam_image); match that order.
            outputs=[label_probs_video, output_image_video],
            fn=predict_video,
            cache_examples=False,
        )
        btn_video.click(predict_video, inputs=input_video,
                        outputs=[label_probs_video, output_image_video],
                        api_name="/predict_video")

if __name__ == "__main__":
    demo.launch()