deepsaif / app.py
22GC22's picture
Upload 17 files
2fbad05 verified
import streamlit as st
from PIL import Image
from models import mesonet, mesoinception, fal_detector, local_detector
from utils.visualization import display_results
from utils.preprocessing import preprocess_image, preprocess_video
from utils.preprocessing import preprocess_image, generate_local_image
import numpy as np
import cv2
import tempfile
import os
# Seed the per-session keys used by the tab sections below. A key that is
# already present (i.e. on a rerun within the same session) is left untouched.
_SESSION_DEFAULTS = (
    ("active_tab", "Face Photoshop Detection"),
    ("uploaded_file", None),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Load models
@st.cache_resource
def _load_models():
    """Load every detector and return them keyed by their display name.

    Streamlit re-executes this script from the top on each widget
    interaction, so loading the weights inline would re-read all four weight
    files on every rerun. ``st.cache_resource`` keeps the returned dict alive
    across reruns so the loaders run only once.

    NOTE(review): cached resources are shared between sessions and are not
    copied; this assumes the ``predict_*`` helpers do not mutate the models.

    Returns:
        dict: display name -> loaded model object, in the order the app
        presents them.
    """
    return {
        "MesoNet": mesonet.load_mesonet("models/weights/Meso4_DF.h5"),
        "MesoInception": mesoinception.load_mesonetInception("models/weights/MesoInception_DF.h5"),
        "Photoshop FALdetector Global": fal_detector.load_fal_detector("models/weights/global.pth"),
        # gpu_id=-1 presumably selects CPU execution -- TODO confirm in local_detector.
        "Photoshop FALdetector Local": local_detector.load_local_detector("models/weights/local.pth", gpu_id=-1),
    }


models = _load_models()
st.title("DeepSAIF")

# One tab per detection workflow; the labels double as the values stored in
# st.session_state["active_tab"] by the tab sections that follow.
_TAB_LABELS = [
    "Face Photoshop Detection",
    "DeepFake Detection for Images",
    "DeepFake Detection for Videos",
]
tab1, tab2, tab3 = st.tabs(_TAB_LABELS)
# Tab 1: Photoshop Detection
# Runs the two FALdetector models (global classifier + local heatmap) on an
# uploaded still image and shows their scores.
with tab1:
    # NOTE(review): st.tabs executes the body of every tab on each rerun, so
    # this guard does not actually detect which tab the user is viewing. It is
    # kept so the session-state keys are maintained exactly as before.
    if st.session_state["active_tab"] != "Face Photoshop Detection":
        st.session_state["uploaded_file"] = None
        st.session_state["active_tab"] = "Face Photoshop Detection"

    st.header("Face Photoshop Detection")
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png"], key="photoshop")

    if uploaded_file:
        st.session_state["uploaded_file"] = uploaded_file
        image = Image.open(uploaded_file).convert("RGB")
        # use_container_width replaces the deprecated use_column_width and
        # matches the heatmap display further down.
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # NOTE(review): the result is not used in this tab; the call is kept in
        # case generate_local_image has side effects -- confirm and remove.
        local_image = generate_local_image(image)

        # Run inference on the Photoshop detectors (global first, then local,
        # the same order in which they appear in `models`).
        results = {}

        global_name = "Photoshop FALdetector Global"
        results[global_name] = fal_detector.predict_fal_detector(models[global_name], image)

        local_name = "Photoshop FALdetector Local"
        heatmap_path, prediction = local_detector.predict_and_generate_heatmap(models[local_name], image)
        if heatmap_path:
            # Display the heatmap using Streamlit
            st.image(heatmap_path, caption=f"Heatmap for {local_name}", use_container_width=True)
            # Remove the heatmap and the fixed-name intermediate images
            # (presumably written by the local detector -- TODO confirm).
            # A file that is already gone must not crash the page after the
            # heatmap has been shown, so cleanup is best-effort per file.
            for temp_path in (heatmap_path, "cropped_input.jpg", "warped.jpg"):
                try:
                    os.remove(temp_path)
                except FileNotFoundError:
                    pass
        else:
            st.error(f"Failed to generate heatmap for {local_name}")
        results[local_name] = prediction

        # Display results
        display_results(results)
# Tab 2: DeepFake Detection for Images
# Runs the two Meso* classifiers on an uploaded still image and shows their
# scores.
with tab2:
    # NOTE(review): st.tabs executes the body of every tab on each rerun, so
    # this guard does not actually detect which tab the user is viewing. It is
    # kept so the session-state keys are maintained exactly as before.
    if st.session_state["active_tab"] != "DeepFake Detection for Images":
        st.session_state["uploaded_file"] = None
        st.session_state["active_tab"] = "DeepFake Detection for Images"

    st.header("DeepFake Detection for Images")
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png"], key="deepfake_image")

    if uploaded_file:
        st.session_state["uploaded_file"] = uploaded_file
        image = Image.open(uploaded_file).convert("RGB")
        # use_container_width replaces the deprecated use_column_width and
        # matches the heatmap display in the Photoshop tab.
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # NOTE(review): the result is not used in this tab; the call is kept in
        # case generate_local_image has side effects -- confirm and remove.
        local_image = generate_local_image(image)

        # Run inference on the deepfake classifiers, in the same order in
        # which they appear in `models`.
        results = {
            "MesoNet": mesonet.predict_mesonet(models["MesoNet"], image),
            "MesoInception": mesoinception.predict_mesonetInception(models["MesoInception"], image),
        }

        # Display results
        display_results(results)
def confident_strategy(pred, t=0.8):
    """
    Collapse per-frame scores into a single video-level score.

    The aggregation picks one of three averages:

    * If more than ``len(pred) // 2.5`` frames score above ``t`` *and* more
      than 11 frames do, return the mean of only those high-scoring frames.
    * Otherwise, if more than 90% of the frames score below 0.2, return the
      mean of only those low-scoring frames.
    * Otherwise return the mean of all frames.

    Args:
        pred (list[float]): Per-frame prediction scores.
        t (float): Score above which a frame counts as a confident fake.

    Returns:
        float: Aggregated score, or ``np.nan`` when ``pred`` is empty.
    """
    if len(pred) == 0:
        return np.nan

    scores = np.array(pred)
    total = len(scores)

    # Frames the model is confident are fake.
    high_mask = scores > t
    high_count = np.count_nonzero(high_mask)
    if high_count > total // 2.5 and high_count > 11:
        return np.mean(scores[high_mask])

    # Frames the model is confident are real.
    low_mask = scores < 0.2
    if np.count_nonzero(low_mask) > 0.9 * total:
        return np.mean(scores[low_mask])

    return np.mean(scores)
# Tab 3: DeepFake Detection for Videos
# Samples frames from an uploaded video, scores each frame with MesoNet and
# the global FALdetector, and aggregates the per-frame scores per model.
with tab3:
    # NOTE(review): st.tabs executes the body of every tab on each rerun, so
    # this guard does not actually detect which tab the user is viewing. It is
    # kept so the session-state keys are maintained exactly as before.
    if st.session_state["active_tab"] != "DeepFake Detection for Videos":
        st.session_state["uploaded_file"] = None
        st.session_state["active_tab"] = "DeepFake Detection for Videos"

    st.header("DeepFake Detection for Videos")
    uploaded_file = st.file_uploader("Upload Video", type=["mp4", "avi", "mov"], key="deepfake_video")

    if uploaded_file:
        st.session_state["uploaded_file"] = uploaded_file
        with st.spinner("Processing video..."):
            # Save uploaded file to a temporary location. delete=False keeps
            # the file on disk after the handle is closed so it can be
            # reopened by path; the `finally` below removes it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                temp_file.write(uploaded_file.getbuffer())
                video_path = temp_file.name

            try:
                # Test video accessibility. The capture is only a probe, so
                # release it immediately: leaving it open leaks the handle and
                # can prevent the temp file from being deleted on Windows.
                cap = cv2.VideoCapture(video_path)
                try:
                    video_opened = cap.isOpened()
                finally:
                    cap.release()

                if not video_opened:
                    st.error("Failed to open video file.")
                else:
                    st.success("Video file opened successfully!")

                    # Extract frames from video
                    frames = preprocess_video(video_path, frame_count=32)
                    # len() rather than truthiness: frames may be an array-like
                    # whose truth value is ambiguous.
                    if len(frames) == 0:
                        st.error("Failed to extract frames from the video.")
                    else:
                        # Per-frame predictions, keyed by model display name.
                        model_results = {
                            "MesoNet": [],
                            "Photoshop FALdetector Global": []
                        }

                        # Iterate over frames and make predictions for each model
                        for frame in frames:
                            preprocessed_frame = preprocess_image(frame)  # Preprocess frame
                            local_image = generate_local_image(preprocessed_frame)

                            # MesoNet scores the preprocessed frame; the global
                            # FALdetector scores the derived local image.
                            model_results["MesoNet"].append(
                                mesonet.predict_mesonet(models["MesoNet"], preprocessed_frame)
                            )
                            model_results["Photoshop FALdetector Global"].append(
                                fal_detector.predict_fal_detector(models["Photoshop FALdetector Global"], local_image)
                            )

                        # Apply the confident averaging strategy for each model
                        final_results = {
                            model_name: confident_strategy(predictions)
                            for model_name, predictions in model_results.items()
                        }

                        # Display results
                        st.write("### Video Analysis Results")
                        display_results(final_results)

                        # Optionally show detailed frame predictions per model
                        if st.checkbox("Show Detailed Frame Predictions"):
                            for model_name, predictions in model_results.items():
                                st.write(f"### Predictions for {model_name}")
                                st.bar_chart(predictions)
            finally:
                # Clean up temporary file
                os.remove(video_path)