varma123's picture
Update app.py
be9e570 verified
raw
history blame contribute delete
No virus
3.45 kB
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import warnings
# Silence every warning emitted by the libraries below (deliberately global).
warnings.filterwarnings("ignore")

# Download and Load Model
# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Face detector. Module.to()/eval() return the module itself, so rebinding keeps
# the same object while switching it to inference mode on DEVICE.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE)
mtcnn = mtcnn.to(DEVICE).eval()

# Real/fake classifier: InceptionResnetV1 built with pretrained="vggface2"
# (facenet_pytorch presumably fetches those weights on first use — hence
# "Download" above) and a single-output head (num_classes=1).
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)

# Fine-tuned weights are read onto the CPU first, then the model is moved to DEVICE.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE).eval()
# Model Inference
def predict_frame(frame):
    """Classify the face in one video frame as real or fake and build a Grad-CAM overlay.

    Args:
        frame: A BGR image array, as produced by ``cv2.VideoCapture.read``.

    Returns:
        A ``(prediction, face_with_mask)`` tuple. ``prediction`` is ``"real"`` or
        ``"fake"``; ``face_with_mask`` is a uint8 RGB image of the resized face
        crop with its Grad-CAM heatmap blended on top. Both are ``None`` when
        MTCNN finds no face in the frame.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    face = mtcnn(Image.fromarray(rgb_frame))
    if face is None:
        return None, None  # No face detected

    # Resize the crop to 256x256 and scale it to [0, 1]. The division by 255
    # assumes MTCNN(post_process=False) returns raw 0-255 pixel values — TODO confirm.
    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE, dtype=torch.float32) / 255.0

    with torch.no_grad():
        fake_probability = torch.sigmoid(model(face).squeeze(0)).item()
    # The single sigmoid output is read as P(fake): below 0.5 counts as "real".
    prediction = "real" if fake_probability < 0.5 else "fake"

    return prediction, _gradcam_overlay(face)


def _gradcam_overlay(face):
    """Return ``face`` as a uint8 RGB image with its Grad-CAM heatmap blended in.

    Args:
        face: The (1, C, 256, 256) float tensor that was fed to ``model``;
            values are expected in [0, 1].
    """
    target_layers = [model.block8.branch1[-1]]
    # No ``use_cuda`` argument: newer pytorch-grad-cam releases removed it (the
    # device is taken from the model), and ``face`` already lives on DEVICE.
    # The context manager releases the forward/backward hooks deterministically.
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_cam = cam(
            input_tensor=face,
            targets=[ClassifierOutputTarget(0)],
            eigen_smooth=True,
        )[0, :]

    face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # show_cam_on_image rejects inputs above 1; guard against float rounding.
    face_np = np.clip(face_np, 0.0, 1.0)
    # show_cam_on_image already returns uint8 in [0, 255]. Multiplying it by 255
    # again (as the previous code did) overflowed and wrapped the pixel values.
    visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
    return cv2.addWeighted((face_np * 255).astype(np.uint8), 1, visualization, 0.5, 0)
def predict_video(input_video, skip_frames=20):
    """Classify a video as real or fake by majority vote over sampled frames.

    Every ``skip_frames``-th frame, starting with the very first one, is passed
    to ``predict_frame``. Frames in which no face is detected do not vote.

    Args:
        input_video: Source accepted by ``cv2.VideoCapture`` (presumably the
            file path supplied by the ``gr.Video`` input — confirm).
        skip_frames: Sampling stride in frames; must be >= 1. Defaults to 20.

    Returns:
        ``"fake"`` if strictly more sampled faces were classified fake than
        real, ``"real"`` otherwise, or ``"no face detected"`` when no sampled
        frame contained a face (previously this case silently reported "real").

    Raises:
        ValueError: If no video is given, ``skip_frames`` < 1, or the video
            cannot be opened.
    """
    if input_video is None:
        raise ValueError("No input video provided")
    if skip_frames < 1:
        raise ValueError(f"skip_frames must be >= 1, got {skip_frames}")

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {input_video!r}")

    fake_votes = 0
    real_votes = 0
    frame_index = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Sample frames 0, skip_frames, 2*skip_frames, ... so that clips
            # shorter than ``skip_frames`` frames are still analysed.
            is_sampled = frame_index % skip_frames == 0
            frame_index += 1
            if not is_sampled:
                continue
            # Only the label is used; the Grad-CAM overlay is discarded instead
            # of being accumulated for the whole video.
            prediction, _ = predict_frame(frame)
            if prediction == "fake":
                fake_votes += 1
            elif prediction == "real":
                real_votes += 1
    finally:
        # Release the capture even if predict_frame raises.
        cap.release()

    if fake_votes == 0 and real_votes == 0:
        return "no face detected"
    # Ties resolve to "real", matching the original majority rule.
    return "fake" if fake_votes > real_votes else "real"
# Gradio Interface
# One video upload in, one class label out; predict_video does the actual work.
interface = gr.Interface(
    title="Deep fake video Detection",
    description="Detect whether the Video is fake or real",
    fn=predict_video,
    inputs=[gr.Video(label="Input Video")],
    outputs=[gr.Label(label="Class")],
)

# Start the Gradio app. Note: this runs as a module-level side effect.
interface.launch()