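# Gradio Space: frame-by-frame deepfake detection with MTCNN face detection,
# an InceptionResnetV1 classifier, and Grad-CAM visualizations.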
import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import os
import numpy as np
from PIL import Image
import zipfile
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import tempfile
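# Unpack the bundled example images into ./examples/.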
with zipfile.ZipFile("examples.zip", "r") as zip_ref:
    zip_ref.extractall(".")
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
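# Face detector. select_largest=False returns the highest-probability face
# rather than the largest box; post_process=False keeps raw [0, 255] pixels.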
mtcnn = MTCNN(
    select_largest=False,
    post_process=False,
    device=DEVICE
).to(DEVICE).eval()
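# Binary classifier: InceptionResnetV1 pretrained on VGGFace2, re-headed to a
# single logit (sigmoid > 0.5 means "fake"); fine-tuned weights are loaded below.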
model = InceptionResnetV1(
    pretrained="vggface2",
    classify=True,
    num_classes=1,
    device=DEVICE
)
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
# Build a shuffled list of example images; the label is the filename prefix
# before the first underscore. (Not currently wired into the interface below.)
EXAMPLES_FOLDER = 'examples'
examples_names = os.listdir(EXAMPLES_FOLDER)
examples = []
for example_name in examples_names:
    example_path = os.path.join(EXAMPLES_FOLDER, example_name)
    label = example_name.split('_')[0]
    examples.append({'path': example_path, 'label': label})
np.random.shuffle(examples)
def process_frame(frame, mtcnn, model, cam, targets):
    """Detect a face in a BGR frame; return (Grad-CAM visualization, label, confidence)."""
    face = mtcnn(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    if face is None:
        return None, None, None
    face = face.unsqueeze(0)  # add batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
    # post_process=False leaves pixel values in [0, 255]; scale to [0, 1]
    face = face.to(DEVICE, dtype=torch.float32) / 255.0
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
    prediction = "real" if output.item() < 0.5 else "fake"
    confidence = 1 - output.item() if prediction == "real" else output.item()
    return visualization, prediction, confidence
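# Single-image sanity check (sketch; "face.jpg" is a hypothetical local file,
# and `cam`/`targets` must be built first, as in predict_video below):
#   cam = GradCAM(model=model, target_layers=[model.block8.branch1[-1]])
#   targets = [ClassifierOutputTarget(0)]
#   vis, label, conf = process_frame(cv2.imread("face.jpg"), mtcnn, model, cam, targets)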
@spaces.GPU  # assumed: requests a GPU on Hugging Face ZeroGPU (why `spaces` is imported above)
def predict_video(input_video: str):
    """Predict the labels for each frame of the input video."""
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Grad-CAM over the last conv layer of the final Inception-ResNet block
    target_layers = [model.block8.branch1[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(0)]
    temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    out = cv2.VideoWriter(temp_output.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        processed_frame, prediction, confidence = process_frame(frame, mtcnn, model, cam, targets)
        if processed_frame is not None:
            # Resize the Grad-CAM visualization to the original video dimensions
            processed_frame = cv2.resize(processed_frame, (width, height))
            # show_cam_on_image returns RGB; OpenCV's writer expects BGR
            processed_frame = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
            # Overlay the prediction and confidence
            text = f"{prediction}: {confidence:.2f}"
            cv2.putText(processed_frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            out.write(processed_frame)
        else:
            # No face detected: write the original frame unchanged
            out.write(frame)
    cap.release()
    out.release()
    return temp_output.name
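# Usage sketch outside Gradio (hypothetical path):
#   out_path = predict_video("my_clip.mp4")
#   # -> temp .mp4 with a Grad-CAM overlay and label on every detected face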
interface = gr.Interface(
    fn=predict_video,
    inputs=[gr.Video(label="Input Video")],
    outputs=[gr.Video(label="Output Video")],
    title="Video Deepfake Detection",
    description="Upload a video to detect deepfakes in each frame."
)
if __name__ == "__main__":
    interface.launch()