# ProgramSkripsi / app.py
# Repository-page residue (kept as comments so the file parses as Python):
#   Yuuki0's picture
#   fix HF
#   5728d33
import gradio as gr
import torch
import os
import cv2
import urllib.request
from model.pred_func import load_genconvit, df_face, pred_vid, real_or_fake
from model.config import load_config
# --- Model Download ---
def download_models(weight_dir='weight'):
    """
    Download the pre-trained GenConViT weight files that are not present yet.

    Each file is first written to ``<name>.part`` and only renamed to its
    final name after the transfer finished. An interrupted or failed download
    therefore never leaves a truncated ``.pth`` behind that a later start
    would mistake for a complete weight file.

    Args:
        weight_dir: Directory the ``.pth`` files are stored in; created when
            missing. Defaults to ``'weight'``.

    Raises:
        Any exception raised by ``urllib.request.urlretrieve`` or by the
        filesystem calls is propagated to the caller after the partial file
        has been removed.
    """
    base_url = 'https://huggingface.co/Deressa/GenConViT/resolve/main/'
    weight_files = (
        ('ED', 'genconvit_ed_inference.pth'),
        ('VAE', 'genconvit_vae_inference.pth'),
    )

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(weight_dir, exist_ok=True)

    for label, filename in weight_files:
        target_path = os.path.join(weight_dir, filename)
        if os.path.exists(target_path):
            continue

        partial_path = target_path + '.part'
        print(f"Downloading {label} model weights...")
        try:
            urllib.request.urlretrieve(base_url + filename, partial_path)
            # Atomic rename: the final name only ever refers to a full file.
            os.replace(partial_path, target_path)
        finally:
            # Still present only when the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Download complete.")
# --- Global Variables ---
# Configuration returned by load_config(); handed to load_genconvit() in
# load_model_once().
config = load_config()
# Shared model instance; stays None until load_model_once() assigns it.
model = None
def load_model_once():
    """
    Populate the module-level ``model`` on the first call.

    Subsequent calls return immediately because ``model`` is already set.
    The first call makes sure the weight files exist (via
    ``download_models``) before handing them to ``load_genconvit``.
    """
    global model
    if model is not None:
        return

    download_models()
    print("Loading GenConViT model...")
    # net='genconvit' uses both ED and VAE, as per prediction.py logic for
    # best results.
    model = load_genconvit(
        config,
        net='genconvit',
        ed_weight='genconvit_ed_inference',
        vae_weight='genconvit_vae_inference',
        fp16=False,
    )
    print("Model loaded successfully.")
def get_video_duration(video_path):
    """
    Return the duration of a video in seconds.

    The duration is computed as frame count divided by frames per second,
    both read from the container metadata through OpenCV.

    Args:
        video_path: Path of the video file to inspect.

    Returns:
        The duration in seconds as a float, or 0 when the file cannot be
        opened or its metadata does not yield a positive frame rate and
        frame count.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return 0
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        # Release the handle on every path, including the early return.
        cap.release()

    # Written as `not (x > 0)` so NaN is rejected too, not just 0/negatives.
    if not (fps > 0 and frame_count > 0):
        return 0
    return frame_count / fps
# --- Prediction Function ---
def detect_deepfake(video_path, model_type, num_frames):
    """
    Run deepfake detection on an uploaded video.

    Args:
        video_path: Path of the uploaded video, or None when nothing was
            uploaded.
        model_type: Model choice from the UI ("GenConViT", "AE" or "VAE").
            Unknown values fall back to "genconvit".
        num_frames: Number of frames passed to ``df_face``; converted to
            ``int`` before use.

    Returns:
        On success, a dict ``{"FAKE": score, "REAL": 1 - score}``.
        Otherwise a message string (no upload, video longer than 60 seconds,
        no faces detected, or a processing error).
    """
    if video_path is None:
        return "❌ Please upload a video file."

    # ===== VIDEO DURATION VALIDATION =====
    # A duration of 0 (metadata unreadable) is let through on purpose; any
    # failure while processing such a file is caught below.
    duration = get_video_duration(video_path)
    if duration > 60:
        return "❌ Video terlalu besar. Durasi maksimal adalah 1 menit (60 detik)."

    try:
        print(f"Processing video: {video_path} with model: {model_type}")

        # Map the UI model name to the internal net identifier.
        net_mapping = {
            "GenConViT": "genconvit",
            "AE": "ed",
            "VAE": "vae"
        }
        net_val = net_mapping.get(model_type, "genconvit")

        # Extract faces from the video. The slider value is cast so a float
        # such as 15.0 cannot reach frame-indexing code.
        faces = df_face(video_path, int(num_frames))
        if len(faces) == 0:
            return "No faces were detected in the video. Please try another video."

        # Make prediction
        y, y_val = pred_vid(faces, model, net=net_val)

        # Get the label (REAL or FAKE); logged for traceability.
        label = real_or_fake(y)
        print(f"Predicted label: {label} (score: {y_val})")

        # y_val is treated as the FAKE-side score: lower means more likely
        # REAL, higher means more likely FAKE. It therefore always goes
        # under "FAKE". Storing `1 - y_val` under "FAKE" when the label is
        # REAL, as before, made REAL videos show up as FAKE.
        # NOTE(review): confirm the meaning of y_val against pred_vid.
        fake_score = y_val
        return {"FAKE": fake_score, "REAL": 1 - fake_score}
    except Exception as e:
        print(f"An error occurred: {e}")
        return "An error occurred during processing. The video might be corrupted or in an unsupported format."
# --- Gradio Interface ---
# Heading and introductory text passed to gr.Interface below.
title = "GenConViT: Deepfake Video Detection"
description = """
Upload a video file to detect if it's a deepfake. This application uses the Generative Convolutional Vision Transformer (GenConViT)
to analyze the video. The model achieves an average accuracy of 95.8% and an AUC of 99.3% across multiple datasets.
"""
# Load the model once when the app starts. This call runs at import time and,
# through download_models(), fetches the weight files if they are missing.
load_model_once()
# Input components, in the positional order detect_deepfake expects:
# (video_path, model_type, num_frames).
input_widgets = [
    gr.Video(label="Upload Video"),
    gr.Radio(["GenConViT", "AE", "VAE"], label="Pilih Model", value="GenConViT"),
    gr.Slider(1, 200, value=15, step=1, label="Number of Frames"),
]

# Output component that shows the prediction returned by detect_deepfake.
result_widget = gr.Label(num_top_classes=2, label="Prediction Result")

# One row per example: [video file, model choice, number of frames].
example_rows = [
    ["sample_prediction_data/aajsqyyjni.mp4", "GenConViT", 15],
    ["sample_prediction_data/anndvqgoko.mp4", "GenConViT", 15],
    ["sample_prediction_data/0017_fake.mp4.mp4", "GenConViT", 15],
    ["sample_prediction_data/0048_fake.mp4.mp4", "GenConViT", 15],
]

iface = gr.Interface(
    fn=detect_deepfake,
    inputs=input_widgets,
    outputs=result_widget,
    title=title,
    description=description,
    flagging_mode="never",
    examples=example_rows,
)
if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    queued_app = iface.queue()
    queued_app.launch()