import time

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from skimage import exposure
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load models (using free Hugging Face models)
MODEL_NAMES = {
    "Model 1": "dima806/deepfake_vs_real_image_detection",
    "Model 2": "saltacc/anime-ai-detect",
    "Model 3": "rizvandwiki/gansfake-detector",
}

# Human-readable description shown for each model in the UI.
MODEL_INFO = {
    "Model 1": "Trained to detect deepfakes vs real human faces",
    "Model 2": "Specialized in detecting AI-generated anime images",
    "Model 3": "General GAN-generated image detector",
}

# Substrings marking a class label as authentic (not generated). Checkpoints
# name their classes differently (e.g. "Real" vs. "human"), so several
# spellings are accepted. NOTE(review): confirm against each model's
# config.id2label when adding models.
AUTHENTIC_LABEL_KEYWORDS = ("real", "human", "authentic")

# Initialize models. Loading is best-effort: a model that fails to
# download/load is skipped so the app still starts with the others.
models = {}
processors = {}
for name, path in MODEL_NAMES.items():
    try:
        processors[name] = AutoImageProcessor.from_pretrained(path)
        models[name] = AutoModelForImageClassification.from_pretrained(path)
    except Exception as exc:  # narrow enough to let Ctrl-C / SystemExit through
        print(f"Could not load model: {name} ({exc})")


def _is_authentic_label(label):
    """Return True if *label* names an authentic (non-generated) class."""
    lowered = str(label).lower()
    return any(keyword in lowered for keyword in AUTHENTIC_LABEL_KEYWORDS)


def analyze_image(image, selected_model):
    """Run the selected detector on *image* and build every UI output.

    Args:
        image: PIL image from the upload widget, or None if nothing was uploaded.
        selected_model: Display name of a model (a key of MODEL_NAMES).

    Returns:
        A 4-tuple ``(overlay_image, probability_figure, result_text, model_info)``
        matching the Gradio outputs. When analysis cannot run, the image,
        figure and model-info slots are None and ``result_text`` carries the
        message.
    """
    if image is None:
        return None, None, "Please upload an image first", None

    try:
        # Models and the OpenCV overlay below expect 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        model = models.get(selected_model)
        processor = processors.get(selected_model)
        if model is None or processor is None:
            return None, None, "Selected model not available", None

        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # Single image in the batch -> take row 0 of the softmaxed logits.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
        id2label = model.config.id2label

        heatmap = generate_heatmap(image, model, processor)
        chart_fig = create_probability_chart(probs, id2label)
        result_text = format_results(probs, id2label)

        return heatmap, chart_fig, result_text, create_model_info(selected_model)
    except Exception as e:
        # UI boundary: report the error in the results box instead of crashing.
        return None, None, f"Error: {str(e)}", None


def generate_heatmap(image, model, processor):
    """Return *image* blended with a JET-colored, blurred intensity map.

    NOTE: this overlay is NOT derived from the model. ``model`` and
    ``processor`` are accepted for interface compatibility but unused, so the
    colors reflect image brightness only, not what the detector attended to.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    img_array = np.array(image)

    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (21, 21), 0)

    # applyColorMap produces a BGR image; convert it to RGB so the channels
    # line up with img_array before blending (otherwise red/blue are swapped).
    colored = cv2.cvtColor(
        cv2.applyColorMap(blurred, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )

    blended = cv2.addWeighted(img_array, 0.7, colored, 0.3, 0)
    return Image.fromarray(blended)


def create_probability_chart(probs, id2label):
    """Create a horizontal bar chart of class probabilities.

    Args:
        probs: 1-D tensor of class probabilities, indexed like ``id2label``.
        id2label: Mapping from class index to label string.

    Returns:
        A matplotlib ``Figure``. It is built with the object-oriented API
        rather than pyplot, so repeated requests do not accumulate open
        figures in pyplot's global registry.
    """
    labels = [id2label[i] for i in range(len(probs))]
    values = probs.detach().cpu().numpy()
    colors = [
        '#4CAF50' if _is_authentic_label(label) else '#F44336'
        for label in labels
    ]

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    bars = ax.barh(labels, values, color=colors)
    ax.set_xlim(0, 1)
    ax.set_title('Detection Probabilities', pad=20)
    ax.set_xlabel('Probability')

    # Annotate each bar with its value just past the bar's end.
    for bar in bars:
        width = bar.get_width()
        ax.text(width + 0.02, bar.get_y() + bar.get_height() / 2,
                f'{width:.2f}', va='center')

    fig.tight_layout()
    return fig


def format_results(probs, id2label):
    """Format a conclusion line followed by per-class percentages."""
    results = [
        f"{id2label[i]}: {prob*100:.1f}%"
        for i, prob in enumerate(probs.tolist())
    ]

    best_index = int(torch.argmax(probs).item())
    max_prob = probs[best_index].item()
    max_class = id2label[best_index]

    if _is_authentic_label(max_class):
        conclusion = f"Conclusion: This image appears to be AUTHENTIC with {max_prob*100:.1f}% confidence"
    else:
        conclusion = f"Conclusion: This image appears to be FAKE/GENERATED with {max_prob*100:.1f}% confidence"

    return "\n".join([conclusion, "", "Detailed probabilities:"] + results)


def create_model_info(model_name):
    """Return the description of *model_name*, or a fallback message."""
    return MODEL_INFO.get(model_name, "No information available about this model")


# Custom CSS for the interface
custom_css = """
:root {
    --primary: #4b6cb7;
    --secondary: #182848;
    --authentic: #4CAF50;
    --fake: #F44336;
    --neutral: #2196F3;
}
#main-container {
    max-width: 1200px;
    margin: auto;
    padding: 25px;
    border-radius: 15px;
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    box-shadow: 0 8px 32px rgba(0,0,0,0.1);
}
.header {
    text-align: center;
    margin-bottom: 25px;
    background: linear-gradient(90deg, var(--primary) 0%, var(--secondary) 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    padding: 10px;
}
.upload-area {
    border: 3px dashed var(--primary) !important;
    min-height: 300px;
    border-radius: 12px !important;
    transition: all 0.3s ease;
}
.upload-area:hover {
    border-color: var(--secondary) !important;
    transform: translateY(-2px);
}
.result-box {
    padding: 20px;
    border-radius: 12px;
    margin-top: 20px;
    font-size: 1.1em;
    transition: all 0.3s ease;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    background: white;
}
.visualization-box {
    border-radius: 12px;
    padding: 15px;
    background: white;
    margin-top: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.btn-primary {
    background: linear-gradient(90deg, var(--primary) 0%, var(--secondary) 100%) !important;
    color: white !important;
    border: none !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    font-weight: bold !important;
}
.model-select {
    background: white !important;
    border: 2px solid var(--primary) !important;
    border-radius: 8px !important;
    padding: 8px 12px !important;
}
.footer {
    text-align: center;
    margin-top: 20px;
    font-size: 0.9em;
    color: #666;
}
@keyframes fadeIn {
    from { opacity: 0; transform: translateY(10px); }
    to { opacity: 1; transform: translateY(0); }
}
.animation {
    animation: fadeIn 0.5s ease-in-out;
}
.loading {
    animation: pulse 1.5s infinite;
}
@keyframes pulse {
    0% { opacity: 0.6; }
    50% { opacity: 1; }
    100% { opacity: 0.6; }
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="main-container"):
        with gr.Column(elem_classes=["header"]):
            gr.Markdown("# 🛡️ DeepGuard AI")
            gr.Markdown("## Advanced Deepfake Detection System")

        with gr.Row():
            # Gradio column scales are integers; 3:2 keeps the intended 1.5:1 ratio.
            with gr.Column(scale=3):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image for Analysis",
                    elem_classes=["upload-area", "animation"],
                )
                with gr.Row():
                    model_selector = gr.Dropdown(
                        choices=list(MODEL_NAMES.keys()),
                        value=list(MODEL_NAMES.keys())[0],
                        label="Select Detection Model",
                        elem_classes=["model-select", "animation"],
                    )
                    analyze_btn = gr.Button(
                        "Analyze Image",
                        elem_classes=["btn-primary", "animation"],
                    )

            with gr.Column(scale=2):
                with gr.Column(elem_classes=["visualization-box"]):
                    heatmap_output = gr.Image(
                        label="Attention Heatmap",
                        interactive=False,
                    )
                with gr.Column(elem_classes=["visualization-box"]):
                    chart_output = gr.Plot(
                        label="Detection Probabilities",
                    )
                with gr.Column(elem_classes=["result-box", "animation"]):
                    result_output = gr.Textbox(
                        label="Analysis Results",
                        interactive=False,
                        lines=8,
                    )
                with gr.Column(elem_classes=["result-box", "animation"]):
                    model_info = gr.Textbox(
                        label="Model Information",
                        interactive=False,
                        lines=3,
                    )

        # NOTE(review): this footer Markdown is empty in the current source;
        # the `.footer` CSS class above is presumably meant for content here.
        gr.Markdown("\n")

    # Event handlers
    analyze_btn.click(
        fn=analyze_image,
        inputs=[image_input, model_selector],
        outputs=[heatmap_output, chart_output, result_output, model_info],
    )

if __name__ == "__main__":
    demo.launch()