import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import warnings

warnings.filterwarnings("ignore")

# Download and Load Model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Face detector: post_process=False keeps the face crop in the [0, 255] range
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).eval()

# Binary real/fake classifier built on InceptionResnetV1 pretrained on VGGFace2
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)

checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()


# Model Inference
def predict_frame(frame):
    """Predict whether the input frame contains a real or fake face."""
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame)

    face = mtcnn(frame_pil)
    if face is None:
        return None, None  # No face detected

    # Preprocess the face: resize to 256x256 and scale to [0, 1]
    face = F.interpolate(face.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
    face = face.to(DEVICE, dtype=torch.float32) / 255.0

    # Predict
    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
    prediction = "real" if output.item() < 0.5 else "fake"

    # Confidence scores
    confidences = {
        'real': 1 - output.item(),
        'fake': output.item()
    }

    # Visualize with Grad-CAM on the last branch of the final Inception block
    target_layers = [model.block8.branch1[-1]]
    use_cuda = torch.cuda.is_available()
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
    targets = [ClassifierOutputTarget(0)]

    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]

    face_np = face.squeeze(0).permute(1, 2, 0).cpu().numpy()
    # show_cam_on_image expects a float image in [0, 1] and returns a uint8 overlay
    visualization = show_cam_on_image(face_np, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted((face_np * 255).astype(np.uint8), 1, visualization, 0.5, 0)

    return prediction, face_with_mask


def predict_video(input_video):
    cap = cv2.VideoCapture(input_video)
    frames = []
    predictions = []

    frame_count = 0
    skip_frames = 20  # Classify every 20th frame

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        if frame_count % skip_frames != 0:
            # Skip frames whose index is not a multiple of skip_frames
            continue

        prediction, frame_with_mask = predict_frame(frame)
        if prediction is None:
            continue  # No face detected in this frame

        frames.append(frame_with_mask)
        predictions.append(prediction)

    cap.release()

    # Final prediction is the majority vote over the sampled frames
    final_prediction = 'fake' if predictions.count('fake') > predictions.count('real') else 'real'
    return final_prediction


# Gradio Interface
interface = gr.Interface(
    fn=predict_video,
    inputs=[
        gr.Video(label="Input Video")
    ],
    outputs=[
        gr.Label(label="Class"),
    ],
    title="Deepfake Video Detection",
    description="Detect whether a video is real or fake"
)

interface.launch()
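# ---------------------------------------------------------------------------
# Optional local smoke test (a minimal sketch, not part of the app itself):
# calling predict_video() directly on a short clip checks the detection
# pipeline before serving it through Gradio. The file name "sample.mp4" is an
# assumption -- substitute any video on disk. Left commented out because
# interface.launch() above blocks the main thread; run this check before
# launching, or temporarily comment the launch call out while testing.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     print(predict_video("sample.mp4"))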