phiBert / app.py
Titembaye's picture
Optimize: Force CPU mode and add low memory usage for HF Spaces
288834c
"""
Application Gradio - Détecteur de Phishing
Modèle: BERT fine-tuné avec adversarial training
"""
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
# Configuration
MODEL_PATH = "models"
MAX_LENGTH = 256
DEVICE = torch.device("cpu") # Force CPU pour Hugging Face Spaces gratuit
print("="*60)
print("🚀 Initialisation du Détecteur de Phishing")
print("="*60)
# Vérifier que le modèle existe
if not os.path.exists(MODEL_PATH):
raise FileNotFoundError(
f"❌ Modèle introuvable: {MODEL_PATH}\n"
f" Assurez-vous que le dossier existe et contient les fichiers du modèle."
)
# Charger le tokenizer et le modèle
print(f"📥 Chargement du tokenizer depuis {MODEL_PATH}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print(f"📥 Chargement du modèle depuis {MODEL_PATH}...")
# Charger en mode CPU avec optimisations mémoire
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_PATH,
torch_dtype=torch.float32, # Utiliser float32 pour compatibilité CPU
low_cpu_mem_usage=True # Optimisation mémoire
)
model.to(DEVICE)
model.eval()
print(f"✅ Modèle chargé avec succès!")
print(f"🖥️ Device: {DEVICE}")
print("="*60 + "\n")
def predict_phishing(email_text):
"""
Prédit si un email est du phishing ou légitime
Args:
email_text (str): Texte de l'email à analyser
Returns:
tuple: (verdict, probabilités, analyse détaillée)
"""
if not email_text.strip():
return "⚠️ Veuillez entrer un email", {}, ""
# Tokenization
inputs = tokenizer(
email_text,
max_length=MAX_LENGTH,
padding='max_length',
truncation=True,
return_tensors='pt'
)
# Déplacer sur le bon device
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
# Prédiction
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)[0]
predicted_class = torch.argmax(probabilities).item()
confidence = probabilities[predicted_class].item()
# Résultats
label = "🚨 Phishing Détecté" if predicted_class == 1 else "✅ Email Légitime"
prob_dict = {
"Légitime": float(probabilities[0]),
"Phishing": float(probabilities[1])
}
# Analyse détaillée
analysis = f"""
### 📊 Résultats de l'analyse
**Verdict:** {label}
**Confiance:** {confidence * 100:.1f}%
### 🔍 Détails des probabilités
- **Légitime:** {probabilities[0] * 100:.2f}%
- **Phishing:** {probabilities[1] * 100:.2f}%
### 📝 Informations
- **Modèle:** BERT-base-uncased (adversarial training)
- **Longueur du texte:** {len(email_text)} caractères
- **Tokens:** {len(tokenizer.encode(email_text))} tokens
### ⚠️ Avertissement
Cette analyse est fournie à titre éducatif uniquement. En cas de doute sur un email réel,
contactez votre service informatique ou l'expéditeur présumé par un canal sécurisé.
"""
return label, prob_dict, analysis
# Exemples d'emails pour la démo
examples = [
["""Dear valued customer,
Your account has been suspended due to unusual activity.
Please verify your identity immediately by clicking the link below:
http://secure-verify-account.com/login
You have 24 hours to verify or your account will be permanently closed.
Best regards,
Security Team"""],
["""Hi team,
Just a reminder that our weekly meeting is scheduled for tomorrow at 2 PM in Conference Room B.
Please bring your project updates.
Thanks,
John"""],
["""URGENT: You have won $1,000,000 in the international lottery!
To claim your prize, send us your bank details and a processing fee of $500.
Contact us immediately: winner@lottery-prize.com
Congratulations!"""],
["""Hello,
Your package delivery failed.
Track your package here: https://trackpackage.com/xyz123
Delivery company will retry tomorrow between 9 AM - 5 PM.
Tracking ID: XYZ123456"""]
]
# Interface Gradio
with gr.Blocks(theme=gr.themes.Soft(), title="Détecteur de Phishing") as demo:
gr.Markdown("""
# 🛡️ Détecteur de Phishing par Intelligence Artificielle
Cette application utilise un modèle **BERT fine-tuné avec adversarial training**
pour détecter les emails de phishing.
**Axes d'évaluation:**
- 🎯 Robustesse face aux attaques adversariales générées par IA
- 🌐 Généralisation cross-linguale (EN/FR)
---
""")
with gr.Row():
with gr.Column(scale=2):
email_input = gr.Textbox(
label="📧 Collez votre email ici",
placeholder="Entrez le contenu de l'email à analyser...",
lines=10,
max_lines=20
)
with gr.Row():
analyze_btn = gr.Button("🔍 Analyser", variant="primary", size="lg")
clear_btn = gr.ClearButton([email_input], value="🗑️ Effacer")
with gr.Column(scale=1):
verdict_output = gr.Textbox(
label="🎯 Verdict",
interactive=False,
lines=2
)
prob_output = gr.Label(
label="📊 Probabilités",
num_top_classes=2
)
with gr.Row():
analysis_output = gr.Markdown(label="📈 Analyse Détaillée")
# Exemples
gr.Markdown("### 💡 Exemples à tester")
gr.Examples(
examples=examples,
inputs=email_input,
label="Cliquez sur un exemple pour le tester"
)
# Footer
gr.Markdown("""
---
### 📚 À propos
**Projet:** Détection de Phishing par IA - Robustesse Adversariale et Généralisation Cross-Linguale
**Datasets utilisés:**
- Enron Email Dataset (500k emails)
- SMS Spam Collection (5,574 SMS)
- Phishing Email Dataset (18,650 emails)
- Phishing adversariaux générés par Ollama + Gemma3:1b
**Modèle:**
- BERT-base-uncased (110M paramètres)
- Fine-tuné avec adversarial training (50% baseline + 50% adversarial)
⚠️ **Disclaimer:** Cette application est fournie à des fins éducatives et de recherche uniquement.
""")
# Actions
analyze_btn.click(
fn=predict_phishing,
inputs=email_input,
outputs=[verdict_output, prob_output, analysis_output]
)
if __name__ == "__main__":
print("\n" + "="*60)
print("🚀 Lancement de l'application Gradio")
print("="*60)
print(f"📱 Device: {DEVICE}")
print(f"🤖 Modèle: {MODEL_PATH}")
print("="*60 + "\n")
demo.launch(
show_error=True,
server_name="0.0.0.0", # Nécessaire pour Hugging Face Spaces
server_port=7860
)