import streamlit as st
import os
import tempfile
import torch
from huggingface_hub import hf_hub_download
import cv2
from PIL import Image
import numpy as np
import time
import sys
import json

# Make the project root (the parent of this file's directory) importable so the
# local ``model`` package below resolves. Streamlit re-executes this script on
# every interaction, so guard on the path itself to avoid piling up duplicate
# sys.path entries across reruns.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# Import your prediction functions
from model.pred_func import (
    load_genconvit,
    df_face,
    pred_vid,
    real_or_fake,
    set_result,
    store_result,
)
from model.config import load_config

# Hugging Face Hub repository and checkpoint files holding the GenConViT weights.
HF_REPO_ID = "Deressa/GenConViT"
ED_WEIGHTS_FILENAME = "genconvit_ed_inference.pth"
VAE_WEIGHTS_FILENAME = "genconvit_vae_inference.pth"

# Temp-file suffix used when the upload has no usable file extension.
_DEFAULT_VIDEO_SUFFIX = ".mp4"

# Set page config: browser tab title/icon and full-width layout.
st.set_page_config(
    page_title="Deepfake Detection with GenConViT",
    page_icon="🎭",
    layout="wide",
)


@st.cache_resource
def load_model_from_huggingface():
    """Download the GenConViT weights from the Hugging Face Hub and load the model.

    Wrapped in ``st.cache_resource`` so the download and model construction are
    not repeated on every script rerun.

    Returns:
        tuple: ``(model, config)``, where ``model`` is what ``load_genconvit``
        returns and ``config`` is the result of ``load_config()``.
    """
    config = load_config()

    # Download model weights (hf_hub_download returns local file paths).
    with st.spinner("Downloading model weights from Hugging Face Hub..."):
        ed_path = hf_hub_download(repo_id=HF_REPO_ID, filename=ED_WEIGHTS_FILENAME)
        vae_path = hf_hub_download(repo_id=HF_REPO_ID, filename=VAE_WEIGHTS_FILENAME)

    # Load the model
    with st.spinner("Loading model..."):
        model = load_genconvit(config, "genconvit", ed_path, vae_path, fp16=False)

    return model, config


def is_video(file):
    """Return True if OpenCV can open ``file`` and decode at least one frame.

    Any ordinary exception raised while probing is treated as "not a video".
    The capture is always released, including when it fails to open.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(file)
        if not cap.isOpened():
            return False
        ret, _frame = cap.read()
        return ret
    except Exception:
        return False
    finally:
        if cap is not None:
            cap.release()


def _upload_suffix(video_file):
    """Return the upload's file extension (e.g. ``".avi"``), or ``""`` if it has none."""
    name = getattr(video_file, "name", None)
    if isinstance(name, str):
        return os.path.splitext(name)[1]
    return ""


def _read_upload_bytes(video_file):
    """Return the full contents of an uploaded file-like object.

    ``getvalue()`` (BytesIO-like objects such as Streamlit's ``UploadedFile``)
    is preferred because it ignores the current read position, which an earlier
    consumer such as ``st.video`` may have moved; otherwise fall back to ``read()``.
    """
    getvalue = getattr(video_file, "getvalue", None)
    if callable(getvalue):
        return getvalue()
    return video_file.read()


def _remove_file_quietly(path):
    """Best-effort delete of a temporary file; cleanup failure must not mask the result."""
    try:
        os.unlink(path)
    except OSError:
        pass


def process_video(video_file, model, config, num_frames=15):
    """Run GenConViT on an uploaded video and return ``(prediction, confidence)``.

    The upload is written to a temporary file that keeps the upload's own
    extension (``.mp4`` if it has none). Faces are extracted from ``num_frames``
    frames with ``df_face`` and, if any were found, classified with ``pred_vid``.
    The temporary file is removed on every path.

    Args:
        video_file: File-like upload, e.g. a Streamlit ``UploadedFile``.
        model: Model returned by ``load_model_from_huggingface``.
        config: Model configuration; accepted for API compatibility, unused here.
        num_frames: Number of frames handed to ``df_face``.

    Returns:
        tuple[str, float]: ``(real_or_fake(y), float(y_val))`` on success;
        ``("Unable to detect faces", 0.0)`` when ``df_face`` yields nothing;
        ``("Error", 0.0)`` if any step raised (the message is shown with ``st.error``).
    """
    tmp_file_path = None
    try:
        suffix = _upload_suffix(video_file) or _DEFAULT_VIDEO_SUFFIX
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_file_path = tmp_file.name
            tmp_file.write(_read_upload_bytes(video_file))

        # Process the video
        with st.spinner("Extracting faces from video frames..."):
            df = df_face(tmp_file_path, num_frames, "genconvit")

        if len(df) < 1:
            return "Unable to detect faces", 0.0

        with st.spinner("Analyzing video..."):
            y, y_val = pred_vid(df, model)
        # NOTE(review): the meaning of y_val is defined by pred_vid; confirm it
        # is the confidence of the label produced by real_or_fake(y).
        return real_or_fake(y), float(y_val)
    except Exception as e:
        st.error(f"Error processing video: {str(e)}")
        return "Error", 0.0
    finally:
        if tmp_file_path is not None:
            _remove_file_quietly(tmp_file_path)


def main():
    """Render the Streamlit UI: load the model, accept an upload, show the verdict."""
    st.title("Deepfake Detection with GenConViT")
    st.markdown("""
    Upload a video to detect if it's a real or fake (manipulated) facial video.
    This app uses the GenConViT model to analyze facial videos for signs of manipulation.
    """)

    # Load the model; without it the app cannot do anything, so halt the script.
    try:
        model, config = load_model_from_huggingface()
        st.success("✅ Model loaded successfully")
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}")
        st.stop()

    # File uploader
    uploaded_file = st.file_uploader(
        "Choose a video file", type=["mp4", "avi", "mov", "wmv"]
    )

    # Slider for number of frames to process
    num_frames = st.slider(
        "Number of frames to process", min_value=5, max_value=30, value=15
    )

    if uploaded_file is not None:
        # Display the video next to the analysis results.
        col1, col2 = st.columns(2)
        with col1:
            st.video(uploaded_file)

        with col2:
            # Placeholder so the "processing" notice can be cleared once done.
            status = st.empty()
            status.info("Processing your video...")

            # Process the video
            prediction, confidence = process_video(
                uploaded_file, model, config, num_frames
            )
            status.empty()

            # Display results
            if prediction == "FAKE":
                st.error(f"RESULT: {prediction}")
                st.warning(f"Confidence: {confidence:.2f}")
                st.markdown("⚠️ This video appears to be manipulated.")
            elif prediction == "REAL":
                st.success(f"RESULT: {prediction}")
                st.info(f"Confidence: {confidence:.2f}")
                st.markdown("✅ This video appears to be authentic.")
            else:
                st.warning(f"RESULT: {prediction}")

    # Add information about the model
    with st.expander("About the Model"):
        st.markdown("""
        **GenConViT** is a deepfake detection model that combines convolutional neural networks
        with vision transformers to detect manipulated facial videos.

        The model analyzes facial features and determines whether a video is authentic (REAL)
        or manipulated (FAKE).

        GitHub repository: [https://github.com/Deressa/GenConViT](https://github.com/Deressa/GenConViT)
        """)


if __name__ == "__main__":
    main()