# GenConViT deepfake-detection Gradio app (Hugging Face Space).
# NOTE(review): the original paste carried HF Spaces page banners ("Spaces:",
# "Runtime error") — scrape artifacts, not code; check the Space logs for the error.
import os
import urllib.request

import cv2
import gradio as gr
import torch

from model.config import load_config
from model.pred_func import df_face, load_genconvit, pred_vid, real_or_fake
# --- Model Download ---
def download_models():
    """
    Download the pre-trained GenConViT weight files into ./weight if missing.

    Fetches the ED and VAE inference checkpoints from the Hugging Face Hub.
    Safe to call repeatedly: files that already exist are never re-downloaded.
    """
    weight_dir = 'weight'
    ed_url = 'https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth'
    vae_url = 'https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth'
    ed_path = os.path.join(weight_dir, 'genconvit_ed_inference.pth')
    vae_path = os.path.join(weight_dir, 'genconvit_vae_inference.pth')
    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` sequence.
    os.makedirs(weight_dir, exist_ok=True)
    if not os.path.exists(ed_path):
        print("Downloading ED model weights...")
        urllib.request.urlretrieve(ed_url, ed_path)
        print("Download complete.")
    if not os.path.exists(vae_path):
        print("Downloading VAE model weights...")
        urllib.request.urlretrieve(vae_url, vae_path)
        print("Download complete.")
# --- Global Variables ---
config = load_config()  # project-level model configuration, loaded once at import time
model = None  # the GenConViT model; lazily initialized by load_model_once()
def load_model_once():
    """Lazily initialize the global GenConViT model; a no-op after the first call."""
    global model
    if model is not None:
        return
    download_models()
    print("Loading GenConViT model...")
    # net='genconvit' combines the ED and VAE branches, mirroring the
    # prediction.py logic for best results.
    model = load_genconvit(
        config,
        net='genconvit',
        ed_weight='genconvit_ed_inference',
        vae_weight='genconvit_vae_inference',
        fp16=False,
    )
    print("Model loaded successfully.")
def get_video_duration(video_path):
    """Return the duration of *video_path* in seconds, or 0 if it can't be read."""
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return 0
    frames_per_second = capture.get(cv2.CAP_PROP_FPS)
    total_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    # Guard against a zero FPS report to avoid ZeroDivisionError.
    return total_frames / frames_per_second if frames_per_second != 0 else 0
# --- Prediction Function ---
def detect_deepfake(video_path, model_type, num_frames):
    """
    Run deepfake detection on an uploaded video.

    Args:
        video_path: Path to the uploaded video file (None when nothing was uploaded).
        model_type: UI choice — "GenConViT", "AE", or "VAE".
        num_frames: Number of frames to sample for face extraction.

    Returns:
        On success a {"FAKE": p, "REAL": 1 - p} dict for gr.Label;
        otherwise a user-facing error string.
    """
    if video_path is None:
        return "❌ Please upload a video file."
    # Reject videos longer than 60 seconds (message intentionally in Indonesian).
    duration = get_video_duration(video_path)
    if duration > 60:
        return "❌ Video terlalu besar. Durasi maksimal adalah 1 menit (60 detik)."
    try:
        print(f"Processing video: {video_path} with model: {model_type}")
        # Map the UI choice to the internal net identifier.
        net_mapping = {
            "GenConViT": "genconvit",
            "AE": "ed",
            "VAE": "vae"
        }
        net_val = net_mapping.get(model_type, "genconvit")
        # Extract faces from the sampled frames.
        faces = df_face(video_path, num_frames)
        if len(faces) == 0:
            return "No faces were detected in the video. Please try another video."
        # Make prediction.
        y, y_val = pred_vid(faces, model, net=net_val)
        label = real_or_fake(y)
        print(f"Prediction: {label}")
        # BUG FIX: the original returned {"FAKE": 1 - y_val, ...} whenever the
        # label was REAL, inverting the displayed probabilities. Per the
        # original block's own comment, a higher y_val means more likely FAKE
        # (TODO(review): confirm against pred_vid), so y_val maps directly to
        # the FAKE class and its complement to REAL.
        return { "FAKE": y_val, "REAL": 1 - y_val }
    except Exception as e:
        # Top-level boundary for the Gradio handler: log and return a friendly message.
        print(f"An error occurred: {e}")
        return "An error occurred during processing. The video might be corrupted or in an unsupported format."
# --- Gradio Interface ---
# UI copy shown at the top of the Gradio page.
title = "GenConViT: Deepfake Video Detection"
description = """
Upload a video file to detect if it's a deepfake. This application uses the Generative Convolutional Vision Transformer (GenConViT)
to analyze the video. The model achieves an average accuracy of 95.8% and an AUC of 99.3% across multiple datasets.
"""
# Load the model once when the app starts
# (runs at import time so the first request doesn't pay the model-load cost).
load_model_once()
iface = gr.Interface(
    fn=detect_deepfake,
    inputs=[
        gr.Video(label="Upload Video"),
        # "Pilih Model" (Indonesian, "Choose Model") is a runtime UI label.
        gr.Radio(["GenConViT", "AE", "VAE"], label="Pilih Model", value="GenConViT"),
        gr.Slider(1, 200, value=15, step=1, label="Number of Frames")
    ],
    # detect_deepfake returns a {label: prob} dict, rendered by gr.Label.
    outputs=gr.Label(num_top_classes=2, label="Prediction Result"),
    title=title,
    description=description,
    flagging_mode="never",
    # NOTE(review): sample clips assumed to exist under sample_prediction_data/;
    # the doubled ".mp4.mp4" suffixes look like intentional file names — verify.
    examples=[
        ["sample_prediction_data/aajsqyyjni.mp4", "GenConViT", 15],
        ["sample_prediction_data/anndvqgoko.mp4", "GenConViT", 15],
        ["sample_prediction_data/0017_fake.mp4.mp4", "GenConViT", 15],
        ["sample_prediction_data/0048_fake.mp4.mp4", "GenConViT", 15]
    ]
)
if __name__ == "__main__":
    # Serve the app; queue() enables request queuing for long-running predictions.
    iface.queue().launch()