dhairyashah's picture
Update app.py
510c41e verified
import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import zipfile
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import tempfile
# Unpack the bundled example archive into the current working directory.
with zipfile.ZipFile("examples.zip", mode="r") as examples_archive:
    examples_archive.extractall(path=".")
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Face detector. post_process=False is why process_frame divides the detected
# face by 255 itself before feeding it to the classifier.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE)
mtcnn = mtcnn.to(DEVICE).eval()

# Classifier with a single output; process_frame applies a sigmoid to it.
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE,
)

# The checkpoint is deserialised on the CPU first, then the whole model is
# moved to the selected device and switched to inference mode.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
EXAMPLES_FOLDER = 'examples'
examples_names = os.listdir(EXAMPLES_FOLDER)

# One record per file in the examples folder; the label is whatever precedes
# the first underscore in the file name.
examples = [
    {
        'path': os.path.join(EXAMPLES_FOLDER, file_name),
        'label': file_name.split('_')[0],
    }
    for file_name in examples_names
]
np.random.shuffle(examples)  # randomise the order in place
@spaces.GPU
def process_frame(frame, mtcnn, model, cam, targets):
    """Detect a face in one frame, classify it and render a Grad-CAM overlay.

    Args:
        frame: BGR image array (it is converted with ``cv2.COLOR_BGR2RGB``
            before face detection).
        mtcnn: Face detector; called with a PIL image and expected to return
            a face tensor, or ``None`` when no face is found.
        model: Classifier whose single output is passed through a sigmoid.
        cam: Grad-CAM callable, invoked with ``input_tensor``, ``targets``
            and ``eigen_smooth=True``.
        targets: Targets forwarded unchanged to ``cam``.

    Returns:
        A tuple ``(image, prediction, confidence)``.

        * No face detected: ``(frame, None, None)`` with the input frame
          returned untouched.
        * Face detected: ``image`` is the 256x256 Grad-CAM overlay of the
          face crop in BGR channel order (same order as ``frame``),
          ``prediction`` is ``"real"`` or ``"fake"`` and ``confidence`` is
          the probability assigned to that predicted label.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    face = mtcnn(Image.fromarray(rgb_frame))
    if face is None:
        return frame, None, None

    # Add a batch dimension and bring the crop to the 256x256 classifier input.
    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE)
    face = face.to(torch.float32)
    # Scale to [0, 1]; presumably the detector emits 0-255 values because it
    # was built with post_process=False (verify against the MTCNN setup).
    face = face / 255.0

    # Channels-last copy of the crop, used as the background of the overlay.
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    # The overlay is RGB, whereas the no-face path returns the BGR input frame
    # and OpenCV writers expect BGR. Convert so both return paths agree and
    # the colours are not swapped when the frame is written out.
    visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)

    with torch.no_grad():
        fake_probability = torch.sigmoid(model(face).squeeze(0)).item()

    if fake_probability < 0.5:
        prediction = "real"
        confidence = 1 - fake_probability
    else:
        prediction = "fake"
        confidence = fake_probability
    return visualization, prediction, confidence
@spaces.GPU
def predict_video(input_video: str):
    """Annotate every frame of a video with a deepfake prediction.

    Each frame is passed to ``process_frame``. The image it returns is resized
    to the source video's dimensions, labelled with the prediction and its
    confidence (or ``"No prediction available"`` when there is none) and
    appended to an MP4 file.

    Args:
        input_video: Path of the video file to analyse.

    Returns:
        Path of the annotated ``.mp4`` file. The file is created with
        ``delete=False``, so the caller owns its cleanup.

    Raises:
        gr.Error: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        cap.release()
        # Fail loudly instead of returning an empty, unplayable output file.
        raise gr.Error("Could not open the input video.")

    # Reserve an output path, then close our handle straight away so that the
    # VideoWriter is the only writer and no file descriptor is leaked.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_output:
        output_path = temp_output.name

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        target_layers = [model.block8.branch1[-1]]
        targets = [ClassifierOutputTarget(0)]

        # The context manager releases the hooks GradCAM registers on the
        # shared global model, so they do not pile up across requests.
        with GradCAM(model=model, target_layers=target_layers) as cam:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                processed_frame, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
                if processed_frame is None:
                    # Defensive fallback: pass the original frame through.
                    out.write(frame)
                    continue

                # Resize the processed frame to match the original video dimensions
                processed_frame = cv2.resize(processed_frame, (width, height))

                # Add text with prediction and confidence
                if prediction is not None and confidence is not None:
                    text = f"{prediction}: {confidence:.2f}"
                else:
                    text = "No prediction available"
                cv2.putText(processed_frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                out.write(processed_frame)
    finally:
        # Always free the capture and writer, even if a frame fails to process.
        cap.release()
        if out is not None:
            out.release()

    return output_path
# Gradio UI: one video in, one annotated video out, both handled by predict_video.
interface = gr.Interface(
    title="Video Deepfake Detection",
    description="Upload a video to detect deepfakes in each frame.",
    fn=predict_video,
    inputs=[gr.Video(label="Input Video")],
    outputs=[gr.Video(label="Output Video")],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()