Spaces:
Running
Running
Upload 38 files
Browse files- Dockerfile +45 -0
- __init__.py +1 -0
- admin.html +250 -0
- background.js +604 -0
- bert_analyzer.py +375 -0
- bert_finetune.py +216 -0
- cnn_inference.py +237 -0
- cnn_model.py +121 -0
- config.json +35 -0
- content.js +180 -0
- data_collector.py +364 -0
- domain_graph_builder.py +303 -0
- email_analyzer.py +191 -0
- feedback_store.py +223 -0
- generate_icons.py +98 -0
- gmail_scanner.js +193 -0
- gnn_inference.py +274 -0
- gnn_model.py +183 -0
- keep_alive.py +50 -0
- main.py +699 -0
- manifest.json +37 -0
- popup.html +432 -0
- popup.js +332 -0
- render.yaml +34 -0
- requirements.txt +18 -0
- retraining_service.py +295 -0
- screenshot_collector.py +172 -0
- screenshot_hasher.py +214 -0
- special_tokens_map.json +37 -0
- test_endpoint.py +34 -0
- tier3_bert_gnn.py +153 -0
- tokenizer.json +0 -0
- tokenizer_config.json +55 -0
- train_cnn.py +277 -0
- train_gnn.py +228 -0
- url_heuristics.py +326 -0
- visual_analyzer.py +512 -0
- vocab.txt +0 -0
Dockerfile
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PLAYWRIGHT_BROWSERS_PATH=/ms-playwright \
    HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Create a non-root user for Hugging Face security
RUN useradd -m -u 1000 user
WORKDIR /home/user/app

# Install system dependencies required for Playwright and ML libs
RUN apt-get update && apt-get install -y --no-install-recommends \
    wget curl libglib2.0-0 libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 \
    libcups2 libdrm2 libdbus-1-3 libxcb1 libxkbcommon0 libx11-6 libxcomposite1 \
    libxdamage1 libxext6 libxfixes3 libxrandr2 libgbm1 libpango-1.0-0 libcairo2 libasound2 \
    && rm -rf /var/lib/apt/lists/*

# Install PyTorch CPU (pinned; stripped from requirements.txt below so pip
# does not pull the much larger CUDA wheels)
RUN pip install --no-cache-dir torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu

# Copy requirements and install
COPY --chown=user requirements.txt .
RUN grep -v "torch==" requirements.txt | grep -v "torchvision==" > req_filtered.txt && \
    pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r req_filtered.txt

# Install Playwright browser (lands in PLAYWRIGHT_BROWSERS_PATH, readable by all users)
RUN playwright install chromium

# Copy project files
COPY --chown=user . .

# Create necessary directories and set permissions
RUN mkdir -p data logs bert_weights gnn cnn && \
    chmod -R 777 data logs bert_weights

# FIX: the non-root user was created above but the image never switched to
# it, so the app still ran as root. Hugging Face Spaces expects UID 1000.
USER user

# Expose the Hugging Face default port
EXPOSE 7860

# Command to run on start (Port 7860 is required for HF)
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# PhishGuard AI - CNN Module
|
admin.html
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>PhishGuard Admin</title>
|
| 7 |
+
<style>
|
| 8 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
|
| 9 |
+
* { margin:0; padding:0; box-sizing:border-box; }
|
| 10 |
+
:root {
|
| 11 |
+
--bg: #0F0F14; --bg2: #1A1A24; --card: #22222E; --border: rgba(255,255,255,0.06);
|
| 12 |
+
--text: #EAEAF0; --text2: #8888A0; --accent: #534AB7;
|
| 13 |
+
--safe: #22C55E; --danger: #EF4444; --warn: #F59E0B;
|
| 14 |
+
}
|
| 15 |
+
body { font-family:'Inter',sans-serif; background:var(--bg); color:var(--text); min-height:100vh; }
|
| 16 |
+
|
| 17 |
+
/* Login */
|
| 18 |
+
.login-wrap { display:flex; align-items:center; justify-content:center; min-height:100vh; }
|
| 19 |
+
.login-box { background:var(--card); padding:36px; border-radius:16px; border:1px solid var(--border); width:340px; }
|
| 20 |
+
.login-box h2 { font-size:20px; margin-bottom:20px; text-align:center; }
|
| 21 |
+
.login-box input { width:100%; padding:10px 14px; background:var(--bg2); border:1px solid var(--border);
|
| 22 |
+
border-radius:8px; color:var(--text); font-size:14px; font-family:inherit; margin-bottom:12px; outline:none; }
|
| 23 |
+
.login-box input:focus { border-color:var(--accent); }
|
| 24 |
+
.login-box button { width:100%; padding:10px; background:linear-gradient(135deg,var(--accent),#6C5ECE);
|
| 25 |
+
border:none; border-radius:8px; color:#fff; font-size:14px; font-weight:600; cursor:pointer; font-family:inherit; }
|
| 26 |
+
.login-box button:hover { opacity:0.9; }
|
| 27 |
+
.login-error { color:var(--danger); font-size:12px; text-align:center; margin-top:8px; display:none; }
|
| 28 |
+
|
| 29 |
+
/* Dashboard */
|
| 30 |
+
.dashboard { display:none; max-width:1000px; margin:0 auto; padding:24px; }
|
| 31 |
+
.dash-header { display:flex; align-items:center; gap:12px; margin-bottom:24px; }
|
| 32 |
+
.dash-header h1 { font-size:22px; flex:1; }
|
| 33 |
+
.dash-header h1 span { color:var(--accent); }
|
| 34 |
+
.logout-btn { padding:6px 16px; background:var(--card); border:1px solid var(--border);
|
| 35 |
+
border-radius:8px; color:var(--text2); font-size:12px; cursor:pointer; font-family:inherit; }
|
| 36 |
+
|
| 37 |
+
/* Stats Cards */
|
| 38 |
+
.stats { display:grid; grid-template-columns:repeat(auto-fit,minmax(180px,1fr)); gap:12px; margin-bottom:24px; }
|
| 39 |
+
.stat-card { background:var(--card); border:1px solid var(--border); border-radius:12px; padding:16px; }
|
| 40 |
+
.stat-label { font-size:11px; color:var(--text2); text-transform:uppercase; letter-spacing:0.5px; }
|
| 41 |
+
.stat-value { font-size:28px; font-weight:700; margin-top:4px; }
|
| 42 |
+
.stat-sub { font-size:11px; color:var(--text2); margin-top:2px; }
|
| 43 |
+
|
| 44 |
+
/* Table */
|
| 45 |
+
.section-title { font-size:16px; font-weight:600; margin-bottom:12px; }
|
| 46 |
+
.table-wrap { background:var(--card); border:1px solid var(--border); border-radius:12px; overflow:hidden; margin-bottom:24px; }
|
| 47 |
+
table { width:100%; border-collapse:collapse; font-size:12px; }
|
| 48 |
+
th { background:var(--bg2); padding:10px 14px; text-align:left; font-weight:600; color:var(--text2);
|
| 49 |
+
text-transform:uppercase; letter-spacing:0.5px; font-size:10px; }
|
| 50 |
+
td { padding:10px 14px; border-top:1px solid var(--border); }
|
| 51 |
+
tr:hover td { background:rgba(255,255,255,0.02); }
|
| 52 |
+
.badge { display:inline-block; padding:2px 8px; border-radius:4px; font-size:10px; font-weight:600; }
|
| 53 |
+
.badge-phish { background:rgba(239,68,68,0.12); color:var(--danger); }
|
| 54 |
+
.badge-safe { background:rgba(34,197,94,0.12); color:var(--safe); }
|
| 55 |
+
.url-cell { max-width:300px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; font-family:'SF Mono',monospace; font-size:11px; color:var(--text2); }
|
| 56 |
+
|
| 57 |
+
/* Retrain Button */
|
| 58 |
+
.retrain-bar { display:flex; align-items:center; gap:12px; margin-bottom:24px; }
|
| 59 |
+
.retrain-btn { padding:10px 24px; background:linear-gradient(135deg,var(--accent),#6C5ECE);
|
| 60 |
+
border:none; border-radius:8px; color:#fff; font-size:13px; font-weight:600; cursor:pointer; font-family:inherit; }
|
| 61 |
+
.retrain-btn:hover { opacity:0.9; }
|
| 62 |
+
.retrain-btn:disabled { opacity:0.4; cursor:not-allowed; }
|
| 63 |
+
.retrain-status { font-size:12px; color:var(--text2); }
|
| 64 |
+
|
| 65 |
+
/* History */
|
| 66 |
+
.history-card { background:var(--card); border:1px solid var(--border); border-radius:12px; padding:16px; margin-bottom:8px;
|
| 67 |
+
display:flex; align-items:center; gap:16px; }
|
| 68 |
+
.hist-version { font-size:22px; font-weight:700; color:var(--accent); min-width:48px; text-align:center; }
|
| 69 |
+
.hist-detail { flex:1; }
|
| 70 |
+
.hist-detail div:first-child { font-size:13px; font-weight:500; }
|
| 71 |
+
.hist-detail div:last-child { font-size:11px; color:var(--text2); margin-top:2px; }
|
| 72 |
+
.hist-accuracy { font-size:14px; font-weight:600; }
|
| 73 |
+
</style>
|
| 74 |
+
</head>
|
| 75 |
+
<body>
|
| 76 |
+
|
| 77 |
+
<!-- Login Screen -->
|
| 78 |
+
<div class="login-wrap" id="loginScreen">
|
| 79 |
+
<div class="login-box">
|
| 80 |
+
<h2>π‘οΈ PhishGuard Admin</h2>
|
| 81 |
+
<input type="password" id="passInput" placeholder="Admin password" autocomplete="off">
|
| 82 |
+
<button onclick="attemptLogin()">Login</button>
|
| 83 |
+
<div class="login-error" id="loginError">Invalid password</div>
|
| 84 |
+
</div>
|
| 85 |
+
</div>
|
| 86 |
+
|
| 87 |
+
<!-- Dashboard (hidden until login) -->
|
| 88 |
+
<div class="dashboard" id="dashboard">
|
| 89 |
+
<div class="dash-header">
|
| 90 |
+
<h1>Phish<span>Guard</span> Admin</h1>
|
| 91 |
+
<button class="logout-btn" onclick="logout()">Logout</button>
|
| 92 |
+
</div>
|
| 93 |
+
|
| 94 |
+
<!-- Stats -->
|
| 95 |
+
<div class="stats">
|
| 96 |
+
<div class="stat-card">
|
| 97 |
+
<div class="stat-label">Total Feedback</div>
|
| 98 |
+
<div class="stat-value" id="sTotalFeedback">β</div>
|
| 99 |
+
</div>
|
| 100 |
+
<div class="stat-card">
|
| 101 |
+
<div class="stat-label">Phishing Reports</div>
|
| 102 |
+
<div class="stat-value" id="sPhishing" style="color:var(--danger)">β</div>
|
| 103 |
+
</div>
|
| 104 |
+
<div class="stat-card">
|
| 105 |
+
<div class="stat-label">Safe Reports</div>
|
| 106 |
+
<div class="stat-value" id="sSafe" style="color:var(--safe)">β</div>
|
| 107 |
+
</div>
|
| 108 |
+
<div class="stat-card">
|
| 109 |
+
<div class="stat-label">Model Version</div>
|
| 110 |
+
<div class="stat-value" id="sVersion" style="color:var(--accent)">β</div>
|
| 111 |
+
<div class="stat-sub" id="sLastRetrain">Never retrained</div>
|
| 112 |
+
</div>
|
| 113 |
+
<div class="stat-card">
|
| 114 |
+
<div class="stat-label">Unprocessed</div>
|
| 115 |
+
<div class="stat-value" id="sUnprocessed" style="color:var(--warn)">β</div>
|
| 116 |
+
<div class="stat-sub">of 50 needed</div>
|
| 117 |
+
</div>
|
| 118 |
+
</div>
|
| 119 |
+
|
| 120 |
+
<!-- Manual Retrain -->
|
| 121 |
+
<div class="retrain-bar">
|
| 122 |
+
<button class="retrain-btn" id="retrainBtn" onclick="triggerRetrain()">π Trigger Retraining</button>
|
| 123 |
+
<div class="retrain-status" id="retrainStatus"></div>
|
| 124 |
+
</div>
|
| 125 |
+
|
| 126 |
+
<!-- Recent Feedback -->
|
| 127 |
+
<div class="section-title">Recent Feedback</div>
|
| 128 |
+
<div class="table-wrap">
|
| 129 |
+
<table>
|
| 130 |
+
<thead>
|
| 131 |
+
<tr><th>URL</th><th>Label</th><th>Source</th><th>Prediction</th><th>Time</th></tr>
|
| 132 |
+
</thead>
|
| 133 |
+
<tbody id="feedbackTable"><tr><td colspan="5" style="text-align:center;color:var(--text2)">Loading...</td></tr></tbody>
|
| 134 |
+
</table>
|
| 135 |
+
</div>
|
| 136 |
+
|
| 137 |
+
<!-- Retrain History -->
|
| 138 |
+
<div class="section-title">Retrain History</div>
|
| 139 |
+
<div id="historyList"><div style="color:var(--text2);font-size:12px">Loading...</div></div>
|
| 140 |
+
</div>
|
| 141 |
+
|
| 142 |
+
<script>
|
| 143 |
+
const BASE = window.location.origin;
|
| 144 |
+
let authToken = "";
|
| 145 |
+
|
| 146 |
+
// Login
|
| 147 |
+
// Exchange the typed password for an admin token via POST /admin/login.
// Success: hide the login screen, reveal the dashboard, and load its data.
// Any failure (wrong password or network error) shows the inline error.
function attemptLogin() {
  const password = document.getElementById("passInput").value;
  const showError = () => { document.getElementById("loginError").style.display = "block"; };
  fetch(`${BASE}/admin/login`, {
    method: "POST",
    headers: {"Content-Type":"application/json"},
    body: JSON.stringify({password: password})
  })
    .then(r => r.json())
    .then(data => {
      if (!data.success) {
        showError();
        return;
      }
      authToken = data.token;
      document.getElementById("loginScreen").style.display = "none";
      document.getElementById("dashboard").style.display = "block";
      loadDashboard();
    })
    .catch(showError);
}
// Submit on Enter from the password field.
document.getElementById("passInput").addEventListener("keyup", e => { if (e.key === "Enter") attemptLogin(); });
|
| 168 |
+
|
| 169 |
+
// Drop the in-memory token and flip the UI back to the login screen.
function logout() {
  authToken = "";
  const setDisplay = (id, mode) => { document.getElementById(id).style.display = mode; };
  setDisplay("loginScreen", "flex");
  setDisplay("dashboard", "none");
}
|
| 174 |
+
|
| 175 |
+
// Dashboard data
|
| 176 |
+
// Fetch stats + recent feedback from /admin/data and render the dashboard.
// A response containing `error` is treated as a rejected/expired token and
// forces a logout.
function loadDashboard() {
  // Stats
  fetch(`${BASE}/admin/data?token=${authToken}`).then(r=>r.json()).then(data => {
    if (data.error) { logout(); return; }
    const s = data.stats;
    document.getElementById("sTotalFeedback").textContent = s.total_feedback;
    document.getElementById("sPhishing").textContent = s.phishing_corrections;
    document.getElementById("sSafe").textContent = s.safe_corrections;
    document.getElementById("sVersion").textContent = "v" + s.model_version;
    document.getElementById("sUnprocessed").textContent = s.unprocessed_count;
    document.getElementById("sLastRetrain").textContent = s.last_retrain
      ? "Last: " + new Date(s.last_retrain).toLocaleString() : "Never retrained";

    // Feedback table — every user-supplied field goes through esc() to
    // prevent HTML injection into the admin page.
    const rows = data.recent.map(e => `
      <tr>
        <td class="url-cell" title="${esc(e.url)}">${esc(e.url)}</td>
        <td><span class="badge ${e.label==='phishing'?'badge-phish':'badge-safe'}">${esc(e.label)}</span></td>
        <td>${esc(e.source||'β')}</td>
        <td>${e.original_prediction!=null ? (e.original_prediction*100).toFixed(0)+'%' : 'β'}</td>
        <td style="font-size:11px;color:var(--text2)">${e.timestamp ? new Date(e.timestamp).toLocaleString() : 'β'}</td>
      </tr>
    `).join("");
    document.getElementById("feedbackTable").innerHTML = rows || '<tr><td colspan="5" style="text-align:center;color:var(--text2)">No feedback yet</td></tr>';

    // History — reversed so the newest retrain renders first.
    // NOTE(review): .reverse() mutates s.retrain_history in place; harmless
    // here since the array is rebuilt on every fetch, but worth confirming.
    const hist = (s.retrain_history || []).reverse();
    if (hist.length === 0) {
      document.getElementById("historyList").innerHTML = '<div style="color:var(--text2);font-size:12px">No retraining history</div>';
    } else {
      document.getElementById("historyList").innerHTML = hist.map(h => `
        <div class="history-card">
          <div class="hist-version">v${h.version}</div>
          <div class="hist-detail">
            <div>Trained on ${h.samples} samples</div>
            <div>${new Date(h.timestamp).toLocaleString()}</div>
          </div>
          <div class="hist-accuracy" style="color:${h.accuracy>=0.8?'var(--safe)':h.accuracy>=0.6?'var(--warn)':'var(--danger)'}">
            ${(h.accuracy*100).toFixed(1)}%
          </div>
        </div>
      `).join("");
    }
  });
}
|
| 221 |
+
|
| 222 |
+
// Retrain
|
| 223 |
+
// Trigger a manual retrain via POST /admin/retrain. The button is disabled
// while the request is in flight and restored on completion or error.
// NOTE(review): button-label strings appear mojibake-encoded in this
// extraction ("π" / "β³" were presumably emoji); left byte-identical here.
function triggerRetrain() {
  const btn = document.getElementById("retrainBtn");
  btn.disabled = true;
  btn.textContent = "β³ Retraining...";
  document.getElementById("retrainStatus").textContent = "Training in progress...";

  fetch(`${BASE}/admin/retrain?token=${authToken}`, {method:"POST"})
    .then(r=>r.json())
    .then(data => {
      document.getElementById("retrainStatus").textContent = data.message || "Done";
      btn.disabled = false;
      btn.textContent = "π Trigger Retraining";
      // Give the backend a moment to persist new stats before refreshing.
      setTimeout(loadDashboard, 2000);
    })
    .catch(e => {
      document.getElementById("retrainStatus").textContent = "Error: " + e.message;
      btn.disabled = false;
      btn.textContent = "π Trigger Retraining";
    });
}
|
| 243 |
+
|
| 244 |
+
// HTML-escape arbitrary text by round-tripping it through a detached DOM
// node; null/undefined collapse to the empty string.
function esc(s) {
  const node = document.createElement('div');
  node.textContent = String(s || '');
  return node.innerHTML;
}
|
| 245 |
+
|
| 246 |
+
// Auto-refresh every 30s
|
| 247 |
+
setInterval(() => { if(authToken) loadDashboard(); }, 30000);
|
| 248 |
+
</script>
|
| 249 |
+
</body>
|
| 250 |
+
</html>
|
background.js
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ============================================================
|
| 2 |
+
// PhishGuard AI - background.js
|
| 3 |
+
// MV3 Service Worker with feedback, retraining triggers, and
|
| 4 |
+
// model version polling.
|
| 5 |
+
//
|
| 6 |
+
// State (chrome.storage.local):
|
| 7 |
+
// phishguard_feedback_queue: FeedbackRecord[] (max 500, FIFO)
|
| 8 |
+
// scan_count: int (resets at 50)
|
| 9 |
+
// feedback_count: int (labeled samples since last retrain)
|
| 10 |
+
// last_retrain_ts: ISO8601
|
| 11 |
+
// model_version: int
|
| 12 |
+
// session_id: UUIDv4
|
| 13 |
+
//
|
| 14 |
+
// Triggers:
|
| 15 |
+
// 1. scan_count >= 50 AND feedback_count >= 10
|
| 16 |
+
// 2. chrome.alarms "retrain_alarm" (24h) AND feedback_count >= 10
|
| 17 |
+
// ============================================================
|
| 18 |
+
|
| 19 |
+
// ββ Backend URL ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
// Backend endpoints (Render-hosted API service).
const BACKEND_URL = "https://phishguard-api-z2wj.onrender.com";
const ANALYZE_URL = `${BACKEND_URL}/analyze`;
const RETRAIN_URL = `${BACKEND_URL}/retrain`;
const MODEL_VERSION_URL = `${BACKEND_URL}/model_version`;

// Tuning constants.
const CACHE_TTL_MS = 30 * 60 * 1000;   // per-URL verdict cache lifetime (30 min)
const MAX_QUEUE_SIZE = 500;            // FIFO cap on stored feedback records
const RETRAIN_URL_THRESHOLD = 50;      // scan count that (with enough labels) triggers retraining
const MIN_LABELED_SAMPLES = 10;        // labeled samples required before any retrain

// In-memory caches — lost whenever the MV3 service worker is torn down.
const urlCache = new Map();            // url -> { result, ts }
const tabResultCache = new Map();      // tabId -> last analysis result for that tab
const pageSignals = new Map();         // tabId -> page-level signals (populated elsewhere — not visible in this chunk)
|
| 35 |
+
|
| 36 |
+
// ββ TIER 1: Whitelist (O(1) Set lookup) ββββββββββββββββββββββββββββββ
|
| 37 |
+
// TIER 1: allowlist of well-known root domains; Set gives O(1) membership
// checks. Matched against getRootDomain(url), so any subdomain of these
// is also treated as safe.
const WHITELIST = new Set([
  "google.com","youtube.com","facebook.com","amazon.com","wikipedia.org",
  "twitter.com","instagram.com","linkedin.com","microsoft.com","apple.com",
  "github.com","stackoverflow.com","reddit.com","netflix.com","paypal.com",
  "bankofamerica.com","chase.com","wellsfargo.com","yahoo.com","bing.com",
  "outlook.com","office.com","live.com","adobe.com","dropbox.com",
  "zoom.us","slack.com","spotify.com","twitch.tv","ebay.com",
  "walmart.com","target.com","bestbuy.com","airbnb.com",
  "x.com","tiktok.com","pinterest.com","quora.com","medium.com"
]);
|
| 47 |
+
|
| 48 |
+
// Extract the root domain (last two host labels) from a URL, e.g.
// "https://www.a.example.com/x" -> "example.com". Returns null when the
// URL cannot be parsed.
// NOTE(review): naive for multi-part TLDs — "shop.example.co.uk" yields
// "co.uk"; confirm this is acceptable for the whitelist/brand checks.
function getRootDomain(url) {
  let hostname;
  try {
    hostname = new URL(url).hostname;
  } catch {
    return null;
  }
  const labels = hostname.replace(/^www\./, "").split(".");
  return labels.slice(-2).join(".");
}
|
| 55 |
+
|
| 56 |
+
// ββ TIER 2: Local heuristic scoring ββββββββββββββββββββββββββββββββββ
|
| 57 |
+
// TIER 2: score a URL on a 0-100 phishing-likelihood scale from static
// lexical features. Returns { score, signals } where signals is a
// human-readable description of every feature that fired.
function heuristicScore(url) {
  const signals = [];
  let total = 0;
  const lowered = url.toLowerCase();
  const add = (points, label) => { total += points; signals.push(label); };

  // Raw IPv4 address used as the hostname (25 pts)
  if (/https?:\/\/\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/.test(url)) {
    add(25, "IP as hostname");
  }

  // Suspicious TLD substring anywhere in the URL (20 pts, first match only)
  const riskyTlds = [".xyz",".tk",".ml",".ga",".cf",".gq",".pw",".top",".click"];
  const tldHit = riskyTlds.find(t => lowered.includes(t));
  if (tldHit) add(20, `Suspicious TLD (${tldHit})`);

  // Phishing keywords (15 pts regardless of how many matched)
  const keywords = ["login","verify","secure","update","account","banking",
    "signin","reset","confirm","suspend","webscr","cmd","payment","alert"];
  const kwHits = keywords.filter(kw => lowered.includes(kw));
  if (kwHits.length > 0) add(15, `Keywords: ${kwHits.join(", ")}`);

  // Brand name present in the URL but not at the start of the root domain
  // (15 pts, first match only) — classic lookalike/spoof pattern.
  const brands = ["paypal","google","apple","microsoft","amazon","netflix",
    "facebook","instagram","chase","wellsfargo","bankofamerica"];
  try {
    const root = getRootDomain(url);
    const spoofed = brands.find(b => lowered.includes(b) && root && !root.startsWith(b));
    if (spoofed) add(15, `Brand spoofing: ${spoofed}`);
  } catch {}

  // Excessive subdomain depth (10 pts)
  try {
    const depth = new URL(url).hostname.split(".").length - 2;
    if (depth >= 3) add(10, `${depth} subdomains`);
  } catch {}

  // Very long URL (5 pts)
  if (url.length > 100) add(5, `Long URL (${url.length} chars)`);

  // Many hyphens in the hostname (5 pts)
  try {
    const dashCount = (new URL(url).hostname.match(/-/g) || []).length;
    if (dashCount >= 3) add(5, `${dashCount} hyphens in domain`);
  } catch {}

  // Non-standard port (5 pts)
  try {
    const { port } = new URL(url);
    if (port && port !== "80" && port !== "443") {
      add(5, `Non-standard port :${port}`);
    }
  } catch {}

  return { score: Math.min(total, 100), signals };
}
|
| 118 |
+
|
| 119 |
+
// ββ URL Cache ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 120 |
+
// Return the cached analysis result for `url`, or null when absent or
// older than CACHE_TTL_MS. Stale entries are evicted on read.
function getCached(url) {
  const hit = urlCache.get(url);
  if (!hit) return null;
  const stale = Date.now() - hit.ts > CACHE_TTL_MS;
  if (stale) {
    urlCache.delete(url);
    return null;
  }
  return hit.result;
}
|
| 126 |
+
|
| 127 |
+
// Cache an analysis result for `url`. Once the cache exceeds 500 entries
// the oldest is dropped — Map preserves insertion order, so the first key
// returned by the iterator is the oldest.
function setCache(url, result) {
  urlCache.set(url, { result, ts: Date.now() });
  if (urlCache.size <= 500) return;
  const oldestKey = urlCache.keys().next().value;
  urlCache.delete(oldestKey);
}
|
| 134 |
+
|
| 135 |
+
// ββ Badge ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
+
// Paint the per-tab toolbar badge: background color keyed by status
// (unknown statuses fall back to grey) plus optional short text.
function setBadge(tabId, status, text) {
  const palette = {
    safe: "#22C55E", blocked: "#EF4444", warn: "#F59E0B",
    loading: "#534AB7", none: "#888888"
  };
  const color = palette[status] || palette.none;
  chrome.action.setBadgeBackgroundColor({ color, tabId });
  chrome.action.setBadgeText({ text: text || "", tabId });
}
|
| 144 |
+
|
| 145 |
+
// ββ Backend fetch with retry βββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
+
// POST `payload` as JSON to `url` with a 15 s abort timeout.
// On any failure (network error, timeout, or non-2xx status) waits 2 s and
// retries up to `retryCount` more times, then rethrows the last error.
async function fetchBackend(url, payload, retryCount = 1) {
  const controller = new AbortController();
  const timeout = setTimeout(() => controller.abort(), 15000);
  try {
    const response = await fetch(url, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload),
      signal: controller.signal,
    });
    // Clear as soon as the headers arrive so the timer cannot abort a
    // slow body read below.
    clearTimeout(timeout);
    if (!response.ok) throw new Error(`Server ${response.status}`);
    return await response.json();
  } catch (err) {
    if (retryCount > 0) {
      await new Promise(r => setTimeout(r, 2000));
      return fetchBackend(url, payload, retryCount - 1);
    }
    throw err;
  } finally {
    // FIX: clearTimeout previously ran only on the success path; every
    // failed attempt leaked a live 15 s timer that later fired a spurious
    // abort(). finally covers all exits (double-clearing is harmless).
    clearTimeout(timeout);
  }
}
|
| 167 |
+
|
| 168 |
+
// ββ SHA256 hash ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
// SHA-256 of a UTF-8 string, returned as a lowercase hex digest.
async function sha256(text) {
  const bytes = new TextEncoder().encode(text);
  const digest = await crypto.subtle.digest("SHA-256", bytes);
  const hexPairs = [...new Uint8Array(digest)].map(b => b.toString(16).padStart(2, "0"));
  return hexPairs.join("");
}
|
| 174 |
+
|
| 175 |
+
// ββ Storage helpers ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 176 |
+
// Promise wrapper around callback-style chrome.storage.local.get.
async function getStorage(keys) {
  return new Promise(done => chrome.storage.local.get(keys, done));
}
|
| 179 |
+
|
| 180 |
+
// Promise wrapper around callback-style chrome.storage.local.set.
async function setStorage(data) {
  return new Promise(done => chrome.storage.local.set(data, done));
}
|
| 183 |
+
|
| 184 |
+
// Load the persisted feedback queue; an unset key yields an empty array.
async function getQueue() {
  const { phishguard_feedback_queue } = await getStorage(["phishguard_feedback_queue"]);
  return phishguard_feedback_queue || [];
}
|
| 188 |
+
|
| 189 |
+
// Persist the feedback queue, keeping only the newest MAX_QUEUE_SIZE
// records (FIFO eviction from the front).
async function setQueue(queue) {
  const trimmed = queue.length > MAX_QUEUE_SIZE
    ? queue.slice(-MAX_QUEUE_SIZE)
    : queue;
  await setStorage({ phishguard_feedback_queue: trimmed });
}
|
| 196 |
+
|
| 197 |
+
// ββ ON INSTALL βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 198 |
+
// Initialize persistent state and the recurring alarms on (re)install.
// Note: this resets all counters and the feedback queue, so an extension
// update wipes any feedback not yet sent to the backend.
chrome.runtime.onInstalled.addListener(async () => {
  // Fresh anonymous session identifier.
  const sessionId = crypto.randomUUID();
  await setStorage({
    session_id: sessionId,
    scan_count: 0,
    feedback_count: 0,
    last_retrain_ts: null,
    model_version: 0,
    phishguard_feedback_queue: [],
  });

  // 24-hour retraining alarm
  chrome.alarms.create("retrain_alarm", { periodInMinutes: 1440 });
  // 30-minute model polling alarm
  chrome.alarms.create("model_poll_alarm", { periodInMinutes: 30 });

  console.log("[PhishGuard] Installed. Session:", sessionId);
});
|
| 216 |
+
|
| 217 |
+
// ββ ALARM HANDLERS βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 218 |
+
// Periodic triggers: the daily retraining check and the half-hourly
// model-version poll, dispatched by alarm name.
chrome.alarms.onAlarm.addListener(async (alarm) => {
  switch (alarm.name) {
    case "retrain_alarm":
      console.log("[PhishGuard] Retrain alarm fired");
      await checkRetrain("timer");
      break;
    case "model_poll_alarm":
      await pollModelVersion();
      break;
  }
});
|
| 227 |
+
|
| 228 |
+
// ββ MAIN URL LISTENER ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 229 |
+
// Main scan pipeline, run on every completed top-frame http(s) navigation:
// Tier 1 whitelist -> result cache -> Tier 2 local heuristics ->
// Tier 3+4 backend ensemble, with a heuristic-only fallback when the
// backend is unreachable.
chrome.webNavigation.onCompleted.addListener(async (details) => {
  if (details.frameId !== 0) return;
  const pageUrl = details.url;
  if (!pageUrl.startsWith("http")) return;

  const tab = details.tabId;
  const rootDomain = getRootDomain(pageUrl);
  if (!rootDomain) return;

  setBadge(tab, "loading", "β¦");

  // TIER 1: known-good domains skip all analysis.
  if (WHITELIST.has(rootDomain)) {
    const verdict = {
      url: pageUrl, status: "safe", tier: 1, method: "whitelist",
      confidence: 0, heuristic_score: 0, signals: []
    };
    await setStorage({ lastResult: verdict });
    tabResultCache.set(tab, verdict);
    setBadge(tab, "safe", "β");
    return;
  }

  // Serve a previously computed verdict when still cached.
  const hit = getCached(pageUrl);
  if (hit) {
    await setStorage({ lastResult: hit });
    tabResultCache.set(tab, hit);
    setBadge(tab, hit.status, hit.status === "blocked" ? "!" : "β");
    if (hit.status === "blocked") blockPage(tab, pageUrl, hit);
    return;
  }

  // TIER 2: cheap local URL heuristics.
  const heur = heuristicScore(pageUrl);

  if (heur.score >= 80) {
    // High-confidence local verdict: block without contacting the backend.
    const verdict = {
      url: pageUrl, status: "blocked", tier: 2, method: "heuristic",
      confidence: heur.score / 100, heuristic_score: heur.score,
      signals: heur.signals, is_phishing: true
    };
    setCache(pageUrl, verdict);
    await setStorage({ lastResult: verdict });
    tabResultCache.set(tab, verdict);
    setBadge(tab, "blocked", "!");
    blockPage(tab, pageUrl, verdict);
    await storeFeedbackRecord(pageUrl, verdict);
    await incrementScanCount();
    return;
  }

  // TIER 3+4: defer the decision to the backend ensemble.
  const ctx = pageSignals.get(tab) || {};
  try {
    const backend = await fetchBackend(ANALYZE_URL, {
      url: pageUrl,
      heuristic_score: heur.score,
      page_title: ctx.title || "",
      page_snippet: ctx.snippet || "",
    });

    const verdict = {
      url: pageUrl,
      status: backend.is_phishing ? "blocked" : "safe",
      tier: backend.tier || 3,
      method: backend.method || "ensemble",
      confidence: backend.confidence || 0,
      heuristic_score: backend.heuristic_score || heur.score,
      signals: backend.signals || heur.signals,
      is_phishing: backend.is_phishing,
      details: backend.details || {},
    };

    setCache(pageUrl, verdict);
    await setStorage({ lastResult: verdict });
    tabResultCache.set(tab, verdict);

    if (verdict.status === "blocked") {
      setBadge(tab, "blocked", "!");
      blockPage(tab, pageUrl, verdict);
    } else if (verdict.confidence >= 0.4) {
      setBadge(tab, "warn", "?");
    } else {
      setBadge(tab, "safe", "β");
    }

    await storeFeedbackRecord(pageUrl, verdict);

  } catch (err) {
    // Backend unreachable: fall back to the heuristic score alone.
    console.log("[PhishGuard] Backend unreachable:", err.message);
    const offline = {
      url: pageUrl,
      status: heur.score >= 50 ? "blocked" : "safe",
      tier: 2,
      method: "heuristic-fallback",
      confidence: heur.score / 100,
      heuristic_score: heur.score,
      signals: heur.signals,
      is_phishing: heur.score >= 50,
      details: { backend_error: err.message },
    };
    setCache(pageUrl, offline);
    await setStorage({ lastResult: offline });
    tabResultCache.set(tab, offline);

    if (heur.score >= 50) {
      setBadge(tab, "blocked", "!");
      blockPage(tab, pageUrl, offline);
    } else if (heur.score >= 30) {
      setBadge(tab, "warn", "?");
    } else {
      setBadge(tab, "none", "");
    }

    await storeFeedbackRecord(pageUrl, offline);
  }

  await incrementScanCount();
  await checkRetrain("count");
  pageSignals.delete(tab);

}, { url: [{ schemes: ["http", "https"] }] });
|
| 352 |
+
|
| 353 |
+
// ββ Feedback Record Storage ββββββββββββββββββββββββββββββββββββββββββ
|
| 354 |
+
// Append an (initially unlabeled) scan record to the feedback queue;
// the user may later attach "correct"/"incorrect" via submit_feedback.
async function storeFeedbackRecord(url, result) {
  const { session_id } = await getStorage(["session_id"]);
  const record = {
    url,
    verdict: result.is_phishing ? "phishing" : "safe",
    confidence: result.confidence || 0,
    tier_used: result.tier || 0,
    heuristic_score: result.heuristic_score || 0,
    signals: result.signals || [],
    user_feedback: null,
    timestamp: new Date().toISOString(),
    feedback_ts: null,
    url_hash: await sha256(url),
    session_id: session_id || "",
  };

  const queue = await getQueue();
  queue.push(record);
  await setQueue(queue);
}
|
| 374 |
+
|
| 375 |
+
// Bump the persistent scan counter by one.
async function incrementScanCount() {
  const { scan_count } = await getStorage(["scan_count"]);
  await setStorage({ scan_count: (scan_count || 0) + 1 });
}
|
| 379 |
+
|
| 380 |
+
// ββ Block Page βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 381 |
+
// Redirect the tab to the extension's block page, carrying the verdict
// details in the query string for display.
function blockPage(tabId, url, result) {
  chrome.storage.local.set({ lastResult: { ...result, status: "blocked" } });
  tabResultCache.set(tabId, result);

  const pct = Math.round((result.confidence || 0) * 100);
  const target =
    chrome.runtime.getURL("popup.html") +
    "?blocked=1&url=" + encodeURIComponent(url) +
    "&score=" + pct +
    "&method=" + encodeURIComponent(result.method || "");
  chrome.tabs.update(tabId, { url: target });
}
|
| 392 |
+
|
| 393 |
+
// ββ Retrain Check ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 394 |
+
// Decide whether to fire a retrain request: requires enough labeled
// samples, and either the daily timer or the scan-count threshold.
async function checkRetrain(trigger = "count") {
  const queue = await getQueue();
  const labeled = queue.filter(r => r.user_feedback !== null);

  if (labeled.length < MIN_LABELED_SAMPLES) {
    console.log(`[PhishGuard] Not enough labeled samples (${labeled.length}/${MIN_LABELED_SAMPLES})`);
    return;
  }

  const scanCount = (await getStorage(["scan_count"])).scan_count || 0;
  if (trigger !== "timer" && scanCount < RETRAIN_URL_THRESHOLD) return;

  console.log(`[PhishGuard] Triggering retrain: trigger=${trigger}, labeled=${labeled.length}, scans=${scanCount}`);
  await sendRetrainRequest(labeled, trigger);
}
|
| 411 |
+
|
| 412 |
+
// Ship labeled samples to the retrain endpoint; on success reset the
// counters, drop the shipped records, and notify the user.
async function sendRetrainRequest(samples, trigger) {
  const { session_id } = await getStorage(["session_id"]);
  try {
    const result = await fetchBackend(RETRAIN_URL, {
      samples,
      trigger,
      session_id: session_id || "",
      extension_version: "3.0",
    });

    if (result.status !== "success") return;

    // Reset the retrain counters.
    await setStorage({
      scan_count: 0,
      feedback_count: 0,
      last_retrain_ts: new Date().toISOString(),
    });

    // Drop the records that were just shipped to the backend.
    const shipped = new Set(samples.map(s => s.url_hash));
    const remaining = (await getQueue()).filter(r => !shipped.has(r.url_hash));
    await setQueue(remaining);

    showRetrainNotification(result.accuracy_delta || {});
    console.log("[PhishGuard] Retrain success:", result);
  } catch (err) {
    console.error("[PhishGuard] Retrain request failed:", err.message);
  }
}
|
| 445 |
+
|
| 446 |
+
// Pop a desktop notification summarizing the per-model accuracy deltas
// (only deltas that are present and non-zero are shown).
function showRetrainNotification(delta) {
  const pieces = [];
  if (delta.bert) pieces.push(`BERT: ${(delta.bert * 100).toFixed(1)}%`);
  if (delta.gnn) pieces.push(`GNN: ${(delta.gnn * 100).toFixed(1)}%`);
  const summary = pieces.join(", ");

  chrome.notifications.create("retrain_complete", {
    type: "basic",
    iconUrl: "icons/icon48.png",
    title: "PhishGuard AI Updated",
    message: summary
      ? `Models improved! ${summary} accuracy from your feedback`
      : "Models updated with your feedback",
  });
}
|
| 459 |
+
|
| 460 |
+
// ββ Model Version Polling ββββββββββββββββββββββββββββββββββββββββββββ
|
| 461 |
+
// Best-effort poll of the backend model version; a newer version clears
// the stale URL verdict cache and notifies the user.
async function pollModelVersion() {
  try {
    const ctrl = new AbortController();
    const timer = setTimeout(() => ctrl.abort(), 10000);
    const resp = await fetch(MODEL_VERSION_URL, { signal: ctrl.signal });
    clearTimeout(timer);

    if (!resp.ok) return;
    const info = await resp.json();

    const { model_version } = await getStorage(["model_version"]);
    if (info.version <= (model_version || 0)) return;

    await setStorage({ model_version: info.version });
    // New weights invalidate previously cached verdicts.
    urlCache.clear();

    chrome.notifications.create("model_updated", {
      type: "basic",
      iconUrl: "icons/icon48.png",
      title: "PhishGuard Models Updated",
      message: `Model v${info.version} is now active`,
    });
  } catch (err) {
    // Silently fail β model polling is best-effort
  }
}
|
| 488 |
+
|
| 489 |
+
// ββ Message Handler ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 490 |
+
// Central message router for content scripts and the popup.
chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
  // Content script reports page title/snippet used for backend analysis.
  if (msg.type === "page_signals" && sender.tab) {
    pageSignals.set(sender.tab.id, {
      title: msg.title || "",
      snippet: msg.snippet || "",
      signals: msg.signals || [],
    });
  }

  // User feedback on a verdict ("correct" / "incorrect").
  if (msg.type === "submit_feedback") {
    (async () => {
      const queue = await getQueue();
      const pos = queue.findIndex(r => r.url_hash === msg.url_hash);
      if (pos < 0) {
        sendResponse({ success: false, error: "Record not found" });
        return;
      }
      queue[pos].user_feedback = msg.feedback;
      queue[pos].feedback_ts = new Date().toISOString();
      await setQueue(queue);

      const { feedback_count } = await getStorage(["feedback_count"]);
      await setStorage({ feedback_count: (feedback_count || 0) + 1 });

      // New labels may push us over the retraining threshold.
      await checkRetrain("count");

      sendResponse({ success: true });
    })();
    return true; // keep the channel open for the async response
  }

  // Popup status summary.
  if (msg.type === "get_status") {
    (async () => {
      const data = await getStorage([
        "scan_count", "feedback_count", "last_retrain_ts",
        "model_version", "session_id"
      ]);
      const queue = await getQueue();
      const labeled = queue.filter(r => r.user_feedback !== null).length;

      const DAY_MS = 24 * 60 * 60 * 1000;
      const lastRetrain = data.last_retrain_ts ? new Date(data.last_retrain_ts) : null;
      const nextTimerMs = lastRetrain
        ? Math.max(0, DAY_MS - (Date.now() - lastRetrain.getTime()))
        : DAY_MS;

      sendResponse({
        scan_count: data.scan_count || 0,
        feedback_count: data.feedback_count || 0,
        labeled_count: labeled,
        last_retrain_ts: data.last_retrain_ts,
        model_version: data.model_version || 0,
        next_retrain_urls_remaining: Math.max(0, RETRAIN_URL_THRESHOLD - (data.scan_count || 0)),
        next_retrain_time_remaining_ms: nextTimerMs,
        min_labeled_needed: Math.max(0, MIN_LABELED_SAMPLES - labeled),
      });
    })();
    return true;
  }

  // Synchronous per-tab verdict lookup for the popup.
  if (msg.type === "get_tab_result") {
    sendResponse({ result: tabResultCache.get(msg.tabId) || null });
    return false;
  }

  // "Proceed Anyway": cache a user override so the URL is not re-blocked.
  if (msg.type === "whitelist_url") {
    const override = {
      url: msg.url, status: "safe", tier: 0,
      method: "user-override", confidence: 0
    };
    setCache(msg.url, override);
    chrome.storage.local.set({ lastResult: override });
    sendResponse({ success: true });
  }

  // Gmail scanner bridge: forward email content to /analyze/email.
  if (msg.action === "analyzeEmail") {
    const emailURL = ANALYZE_URL.replace(/\/analyze\/?$/, "/analyze/email");
    fetch(emailURL, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(msg.data),
    })
      .then(r => r.ok ? r.json() : Promise.reject(new Error(`${r.status}`)))
      .then(data => sendResponse(data))
      .catch(() => sendResponse({
        status: "error",
        analysis: { isPhishing: false, probability: 0, reason: "Backend unreachable" }
      }));
    return true;
  }
});
|
| 592 |
+
|
| 593 |
+
// ββ Tab cleanup ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 594 |
+
// Closed tab: discard its cached signals and verdict.
chrome.tabs.onRemoved.addListener((tabId) => {
  pageSignals.delete(tabId);
  tabResultCache.delete(tabId);
});
|
| 598 |
+
|
| 599 |
+
// URL change within a tab: the stored verdict no longer applies.
chrome.tabs.onUpdated.addListener((tabId, changeInfo) => {
  if (!changeInfo.url) return;
  tabResultCache.delete(tabId);
  setBadge(tabId, "none", "");
});
|
bert_analyzer.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - bert_analyzer.py
|
| 3 |
+
# Tier 3a: BERT NLP Phishing Classifier
|
| 4 |
+
#
|
| 5 |
+
# Model: ealvaradob/bert-finetuned-phishing (HuggingFace Hub)
|
| 6 |
+
# Tokenization: split on [-./=?&_~%@] to preserve homoglyphs
|
| 7 |
+
# Input: "URL: {tokenized_url}. Title: {title}. Content: {snippet}"
|
| 8 |
+
# Output: P_bert β [0,1]
|
| 9 |
+
# Supports: load, predict, fine-tune, incremental_update, save/load
|
| 10 |
+
# ============================================================
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import re
|
| 15 |
+
import math
|
| 16 |
+
import logging
|
| 17 |
+
import threading
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List, Tuple, Optional, Dict
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger("phishguard.bert")
|
| 22 |
+
|
| 23 |
+
# ββ Model state ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
_classifier = None
|
| 25 |
+
_tokenizer = None
|
| 26 |
+
_model = None
|
| 27 |
+
_use_bert: bool = False
|
| 28 |
+
_bert_load_attempted: bool = False
|
| 29 |
+
_bert_lock = threading.Lock()
|
| 30 |
+
|
| 31 |
+
# Check if transformers library is installed
|
| 32 |
+
_transformers_available: bool = False
|
| 33 |
+
try:
|
| 34 |
+
import transformers as _tf_module
|
| 35 |
+
_transformers_available = True
|
| 36 |
+
logger.info("transformers library found β BERT will lazy-load on first call")
|
| 37 |
+
except ImportError:
|
| 38 |
+
logger.info("transformers not installed β using keyword NLP fallback")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ββ Phishing pattern databases (for keyword fallback) ββββββββββββββββ
|
| 42 |
+
# Multi-word phrases typical of credential-harvesting pages/emails.
PHISHING_TERMS = [
    "verify your account", "suspended",
    "click here immediately", "unusual activity",
    "confirm your identity", "limited time",
    "your password has been", "unauthorized access",
    "act now", "secure your account",
    "login credentials", "reset password immediately",
    "your account will be", "verify your identity",
    "we noticed suspicious",
]

# Single keywords that commonly appear in phishing URLs.
PHISHING_KEYWORDS = [
    "login", "secure", "verify", "account", "update", "confirm",
    "banking", "paypal", "signin", "password", "suspend", "alert",
    "restore", "unusual", "limited", "expire", "urgent", "immediately",
]

# Frequently impersonated brands (matched against the URL).
BRAND_NAMES = [
    "paypal", "google", "apple", "microsoft", "amazon", "netflix",
    "facebook", "instagram", "twitter", "linkedin", "chase", "wells",
    "bankofamerica", "citibank", "usps", "fedex", "ebay",
]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class BERTPhishingClassifier:
    """BERT-based phishing text classifier.

    Wraps a HuggingFace text-classification model with URL-aware
    tokenization and a keyword fallback when transformers is missing.
    """

    DEFAULT_MODEL = "ealvaradob/bert-finetuned-phishing"
    FALLBACK_MODEL = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

    def __init__(self, model_name: Optional[str] = None) -> None:
        # Checkpoint identifier; caller may override the default.
        self.model_name: str = model_name or self.DEFAULT_MODEL
        self._pipeline = None
        self._tokenizer = None
        self._model = None
        self._loaded: bool = False
        self._lock = threading.Lock()
        # Delimiters separating URL tokens without destroying homoglyphs.
        self._re_url_split = re.compile(r"[-./=?&_~%@:]+")
|
| 80 |
+
|
| 81 |
+
def load_model(self) -> None:
|
| 82 |
+
"""Load BERT model from HuggingFace Hub with cache fallback."""
|
| 83 |
+
if self._loaded:
|
| 84 |
+
return
|
| 85 |
+
with self._lock:
|
| 86 |
+
if self._loaded:
|
| 87 |
+
return
|
| 88 |
+
if not _transformers_available:
|
| 89 |
+
logger.warning("transformers not available, BERT disabled")
|
| 90 |
+
return
|
| 91 |
+
try:
|
| 92 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
|
| 93 |
+
|
| 94 |
+
# Try primary model, fall back to smaller model
|
| 95 |
+
for model_id in [self.model_name, self.FALLBACK_MODEL]:
|
| 96 |
+
try:
|
| 97 |
+
self._pipeline = pipeline(
|
| 98 |
+
"text-classification",
|
| 99 |
+
model=model_id,
|
| 100 |
+
truncation=True,
|
| 101 |
+
max_length=512,
|
| 102 |
+
device=-1,
|
| 103 |
+
)
|
| 104 |
+
self._tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 105 |
+
self._model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
| 106 |
+
self.model_name = model_id
|
| 107 |
+
self._loaded = True
|
| 108 |
+
logger.info(f"BERT model loaded: {model_id}")
|
| 109 |
+
return
|
| 110 |
+
except Exception as e:
|
| 111 |
+
logger.warning(f"Failed to load {model_id}: {e}")
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
logger.error("All BERT model candidates failed")
|
| 115 |
+
|
| 116 |
+
except Exception as e:
|
| 117 |
+
logger.error(f"BERT initialization failed: {e}")
|
| 118 |
+
|
| 119 |
+
def tokenize_url(self, url: str) -> str:
|
| 120 |
+
"""
|
| 121 |
+
Split URL on [-./=?&_~%@:] to preserve homoglyphs.
|
| 122 |
+
Example: "paypa1-l0gin.xyz/verify" β "paypa1 l0gin xyz verify"
|
| 123 |
+
"""
|
| 124 |
+
text = url.replace("https://", "").replace("http://", "")
|
| 125 |
+
tokens = self._re_url_split.split(text)
|
| 126 |
+
return " ".join(t for t in tokens if t)
|
| 127 |
+
|
| 128 |
+
def predict(self, url: str, title: str = "", snippet: str = "") -> float:
|
| 129 |
+
"""
|
| 130 |
+
Predict phishing probability for a URL + page context.
|
| 131 |
+
Returns P_bert β [0,1].
|
| 132 |
+
"""
|
| 133 |
+
self.load_model()
|
| 134 |
+
|
| 135 |
+
if self._loaded and self._pipeline is not None:
|
| 136 |
+
return self._predict_bert(url, title, snippet)
|
| 137 |
+
return self._predict_keyword(url, title, snippet)
|
| 138 |
+
|
| 139 |
+
def _predict_bert(self, url: str, title: str, snippet: str) -> float:
|
| 140 |
+
"""BERT model prediction path."""
|
| 141 |
+
url_text = self.tokenize_url(url)
|
| 142 |
+
combined = f"URL: {url_text}. Title: {title}. Content: {snippet[:300]}"
|
| 143 |
+
|
| 144 |
+
result = self._pipeline(combined[:512])[0]
|
| 145 |
+
label = result["label"].upper()
|
| 146 |
+
confidence = result["score"]
|
| 147 |
+
|
| 148 |
+
# Map label to phishing probability
|
| 149 |
+
if any(kw in label for kw in ["SPAM", "PHISH", "MALICIOUS", "LABEL_1", "1"]):
|
| 150 |
+
raw_prob = confidence
|
| 151 |
+
else:
|
| 152 |
+
raw_prob = 1.0 - confidence
|
| 153 |
+
|
| 154 |
+
# Boost with keyword signals
|
| 155 |
+
text_lower = combined.lower()
|
| 156 |
+
phrase_hits = sum(1 for p in PHISHING_TERMS if p in text_lower)
|
| 157 |
+
adjusted = min(raw_prob + (phrase_hits * 0.05), 1.0)
|
| 158 |
+
|
| 159 |
+
return round(adjusted, 4)
|
| 160 |
+
|
| 161 |
+
def _predict_keyword(self, url: str, title: str, snippet: str) -> float:
|
| 162 |
+
"""Keyword-based fallback when BERT is unavailable."""
|
| 163 |
+
combined = f"{url} {title} {snippet}".lower()
|
| 164 |
+
url_lower = url.lower()
|
| 165 |
+
score = 0.0
|
| 166 |
+
|
| 167 |
+
# Keyword hits in URL
|
| 168 |
+
kw_hits = sum(1 for kw in PHISHING_KEYWORDS if kw in url_lower)
|
| 169 |
+
score += min(kw_hits * 0.08, 0.40)
|
| 170 |
+
|
| 171 |
+
# Phrase matches in content
|
| 172 |
+
phrase_hits = sum(1 for p in PHISHING_TERMS if p in combined)
|
| 173 |
+
score += min(phrase_hits * 0.12, 0.48)
|
| 174 |
+
|
| 175 |
+
# Brand spoofing
|
| 176 |
+
for brand in BRAND_NAMES:
|
| 177 |
+
if brand in url_lower:
|
| 178 |
+
if f"{brand}.com" not in url_lower:
|
| 179 |
+
score += 0.20
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
# IP as hostname
|
| 183 |
+
if re.match(r"https?://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", url):
|
| 184 |
+
score += 0.20
|
| 185 |
+
|
| 186 |
+
# Shannon entropy of hostname
|
| 187 |
+
try:
|
| 188 |
+
from urllib.parse import urlparse
|
| 189 |
+
host = urlparse(url if "://" in url else f"http://{url}").hostname or ""
|
| 190 |
+
if host:
|
| 191 |
+
length = len(host)
|
| 192 |
+
freq: Dict[str, int] = {}
|
| 193 |
+
for c in host:
|
| 194 |
+
freq[c] = freq.get(c, 0) + 1
|
| 195 |
+
entropy = -sum(
|
| 196 |
+
(cnt / length) * math.log2(cnt / length) for cnt in freq.values()
|
| 197 |
+
)
|
| 198 |
+
if entropy > 3.5:
|
| 199 |
+
score += 0.10
|
| 200 |
+
except Exception:
|
| 201 |
+
pass
|
| 202 |
+
|
| 203 |
+
return round(min(score, 1.0), 4)
|
| 204 |
+
|
| 205 |
+
def incremental_update(
|
| 206 |
+
self,
|
| 207 |
+
samples: List[Tuple[str, int]],
|
| 208 |
+
lr: float = 1e-5,
|
| 209 |
+
epochs: int = 1,
|
| 210 |
+
label_smoothing: float = 0.1,
|
| 211 |
+
) -> Optional[float]:
|
| 212 |
+
"""
|
| 213 |
+
Incremental update: unfreeze last 2 transformer layers only.
|
| 214 |
+
Returns accuracy_delta (float) or None if update failed.
|
| 215 |
+
|
| 216 |
+
samples: list of (url, label) where label is 0 or 1
|
| 217 |
+
"""
|
| 218 |
+
if not self._loaded or self._model is None or self._tokenizer is None:
|
| 219 |
+
logger.warning("BERT not loaded, cannot incrementally update")
|
| 220 |
+
return None
|
| 221 |
+
|
| 222 |
+
if len(samples) < 5:
|
| 223 |
+
logger.warning(f"Too few samples ({len(samples)}) for BERT update")
|
| 224 |
+
return None
|
| 225 |
+
|
| 226 |
+
try:
|
| 227 |
+
import torch
|
| 228 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 229 |
+
from torch.optim import AdamW
|
| 230 |
+
|
| 231 |
+
device = torch.device("cpu")
|
| 232 |
+
model = self._model.to(device)
|
| 233 |
+
|
| 234 |
+
# Freeze all layers
|
| 235 |
+
for param in model.parameters():
|
| 236 |
+
param.requires_grad = False
|
| 237 |
+
|
| 238 |
+
# Unfreeze last 2 transformer layers + classifier
|
| 239 |
+
if hasattr(model, "bert"):
|
| 240 |
+
encoder_layers = model.bert.encoder.layer
|
| 241 |
+
for layer in encoder_layers[-2:]:
|
| 242 |
+
for param in layer.parameters():
|
| 243 |
+
param.requires_grad = True
|
| 244 |
+
if hasattr(model, "classifier"):
|
| 245 |
+
for param in model.classifier.parameters():
|
| 246 |
+
param.requires_grad = True
|
| 247 |
+
|
| 248 |
+
# Prepare data
|
| 249 |
+
texts = [self.tokenize_url(url) for url, _ in samples]
|
| 250 |
+
labels = [label for _, label in samples]
|
| 251 |
+
|
| 252 |
+
encodings = self._tokenizer(
|
| 253 |
+
texts, truncation=True, padding=True, max_length=512,
|
| 254 |
+
return_tensors="pt"
|
| 255 |
+
)
|
| 256 |
+
label_tensor = torch.tensor(labels, dtype=torch.long).to(device)
|
| 257 |
+
|
| 258 |
+
dataset = TensorDataset(
|
| 259 |
+
encodings["input_ids"].to(device),
|
| 260 |
+
encodings["attention_mask"].to(device),
|
| 261 |
+
label_tensor,
|
| 262 |
+
)
|
| 263 |
+
batch_size = min(len(samples), 16)
|
| 264 |
+
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
| 265 |
+
|
| 266 |
+
# Pre-update accuracy
|
| 267 |
+
model.eval()
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
pre_correct = 0
|
| 270 |
+
for batch in loader:
|
| 271 |
+
ids, mask, labs = batch
|
| 272 |
+
outputs = model(input_ids=ids, attention_mask=mask)
|
| 273 |
+
preds = torch.argmax(outputs.logits, dim=1)
|
| 274 |
+
pre_correct += (preds == labs).sum().item()
|
| 275 |
+
pre_acc = pre_correct / len(samples)
|
| 276 |
+
|
| 277 |
+
# Train
|
| 278 |
+
optimizer = AdamW(
|
| 279 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
| 280 |
+
lr=lr,
|
| 281 |
+
)
|
| 282 |
+
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
| 283 |
+
|
| 284 |
+
model.train()
|
| 285 |
+
for epoch in range(epochs):
|
| 286 |
+
total_loss = 0.0
|
| 287 |
+
for batch in loader:
|
| 288 |
+
ids, mask, labs = batch
|
| 289 |
+
optimizer.zero_grad()
|
| 290 |
+
outputs = model(input_ids=ids, attention_mask=mask)
|
| 291 |
+
loss = loss_fn(outputs.logits, labs)
|
| 292 |
+
loss.backward()
|
| 293 |
+
optimizer.step()
|
| 294 |
+
total_loss += loss.item()
|
| 295 |
+
logger.info(f"BERT incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(loader):.4f}")
|
| 296 |
+
|
| 297 |
+
# Post-update accuracy
|
| 298 |
+
model.eval()
|
| 299 |
+
with torch.no_grad():
|
| 300 |
+
post_correct = 0
|
| 301 |
+
for batch in loader:
|
| 302 |
+
ids, mask, labs = batch
|
| 303 |
+
outputs = model(input_ids=ids, attention_mask=mask)
|
| 304 |
+
preds = torch.argmax(outputs.logits, dim=1)
|
| 305 |
+
post_correct += (preds == labs).sum().item()
|
| 306 |
+
post_acc = post_correct / len(samples)
|
| 307 |
+
|
| 308 |
+
delta = post_acc - pre_acc
|
| 309 |
+
self._model = model
|
| 310 |
+
logger.info(f"BERT incremental update: {pre_acc:.4f} β {post_acc:.4f} (Ξ={delta:+.4f})")
|
| 311 |
+
return round(delta, 4)
|
| 312 |
+
|
| 313 |
+
except Exception as e:
|
| 314 |
+
logger.error(f"BERT incremental update failed: {e}")
|
| 315 |
+
return None
|
| 316 |
+
|
| 317 |
+
def save(self, path: Path) -> None:
|
| 318 |
+
"""Save model and tokenizer to directory."""
|
| 319 |
+
if self._model and self._tokenizer:
|
| 320 |
+
path = Path(path)
|
| 321 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 322 |
+
self._model.save_pretrained(str(path))
|
| 323 |
+
self._tokenizer.save_pretrained(str(path))
|
| 324 |
+
logger.info(f"BERT model saved to {path}")
|
| 325 |
+
|
| 326 |
+
def load_local(self, path: Path) -> bool:
|
| 327 |
+
"""Load model from local directory."""
|
| 328 |
+
path = Path(path)
|
| 329 |
+
if not path.exists():
|
| 330 |
+
return False
|
| 331 |
+
try:
|
| 332 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
|
| 333 |
+
self._tokenizer = AutoTokenizer.from_pretrained(str(path))
|
| 334 |
+
self._model = AutoModelForSequenceClassification.from_pretrained(str(path))
|
| 335 |
+
self._pipeline = pipeline(
|
| 336 |
+
"text-classification",
|
| 337 |
+
model=self._model,
|
| 338 |
+
tokenizer=self._tokenizer,
|
| 339 |
+
truncation=True,
|
| 340 |
+
max_length=512,
|
| 341 |
+
device=-1,
|
| 342 |
+
)
|
| 343 |
+
self._loaded = True
|
| 344 |
+
logger.info(f"BERT model loaded from {path}")
|
| 345 |
+
return True
|
| 346 |
+
except Exception as e:
|
| 347 |
+
logger.error(f"BERT local load failed: {e}")
|
| 348 |
+
return False
|
| 349 |
+
|
| 350 |
+
@property
|
| 351 |
+
def is_loaded(self) -> bool:
|
| 352 |
+
return self._loaded
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# ββ Legacy compatibility βββββββββββββββββββββββββββββββββββββββββββββ
|
| 356 |
+
_default_classifier = BERTPhishingClassifier()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def analyze_text(url: str, page_title: str = "", page_snippet: str = "") -> dict:
    """Legacy wrapper kept for backward compatibility with main.py.

    Delegates to the module-level classifier singleton and repackages the
    probability into the dict shape older callers expect.
    """
    probability = _default_classifier.predict(url, page_title, page_snippet)
    # Label reflects which engine actually scored: real BERT vs keyword fallback.
    label = "BERT" if _default_classifier.is_loaded else "KEYWORD_NLP"
    result = {
        "bert_phishing_prob": probability,
        "phrase_hits": 0,
        "label": label,
        "confidence": probability,
    }
    return result
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def shannon_entropy(s: str) -> float:
    """Utility: measure randomness of a string (Shannon entropy in bits/char)."""
    if not s:
        return 0.0
    length = len(s)
    entropy = 0.0
    for ch in set(s):
        p = s.count(ch) / length
        if p > 0:
            entropy -= p * math.log2(p)
    return entropy
|
bert_finetune.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - bert_finetune.py
|
| 3 |
+
# Full BERT fine-tuning script on PhishTank + TRANCO data
|
| 4 |
+
#
|
| 5 |
+
# Downloads data, fine-tunes ealvaradob/bert-finetuned-phishing
|
| 6 |
+
# 3 epochs, AdamW + linear warmup scheduler
|
| 7 |
+
# Saves to bert_weights/ with save_pretrained()
|
| 8 |
+
# Prints per-epoch: loss / precision / recall / F1
|
| 9 |
+
# ============================================================
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import List, Tuple
|
| 17 |
+
|
| 18 |
+
# Root logger configured for timestamped console output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(message)s",
)
logger = logging.getLogger("phishguard.bert_finetune")

# Fine-tuned weights are written next to this script, under bert_weights/.
BASE_DIR = Path(__file__).parent
BERT_WEIGHTS_DIR = BASE_DIR / "bert_weights"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main() -> None:
    """Fine-tune BERT on PhishTank + TRANCO URLs.

    Pipeline: download URL datasets, tokenize URLs into word-like text,
    fine-tune a phishing-classification BERT with AdamW + linear warmup,
    validate each epoch, and save the best-F1 checkpoint to bert_weights/.
    """
    print("=" * 60)
    print("PhishGuard AI β BERT Fine-Tuning")
    print("=" * 60)

    # ββ Check dependencies βββββββββββββββββββββββββββββββββββββββ
    # Heavy ML deps are imported lazily so the module itself imports cheaply.
    try:
        import torch
        from torch.utils.data import DataLoader, Dataset
        from torch.optim import AdamW
        from transformers import (
            AutoTokenizer,
            AutoModelForSequenceClassification,
            get_linear_schedule_with_warmup,
        )
        from sklearn.metrics import precision_recall_fscore_support
    except ImportError as e:
        print(f"β Missing dependency: {e}")
        print("   Run: pip install torch transformers scikit-learn")
        sys.exit(1)

    # ββ Download data ββββββββββββββββββββββββββββββββββββββββββββ
    from data_collector import download_phishtank, download_tranco, merge_datasets

    print("\nπ₯ Downloading datasets...")
    # NOTE(review): only 50 URLs per class β smoke-test scale, not a real
    # training corpus; confirm whether these limits are intentional.
    phish_urls = download_phishtank(max_urls=50)
    legit_urls = download_tranco(n=50)
    print(f"   Phishing URLs:   {len(phish_urls)}")
    print(f"   Legitimate URLs: {len(legit_urls)}")

    # test_data is produced by the split but not evaluated in this script.
    train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)

    # ββ URL tokenization βββββββββββββββββββββββββββββββββββββββββ
    import re
    _re_url_split = re.compile(r"[-./=?&_~%@:]+")

    def tokenize_url(url: str) -> str:
        # Strip the scheme and split on URL delimiters so BERT sees
        # word-like tokens instead of one long opaque string.
        text = url.replace("https://", "").replace("http://", "")
        tokens = _re_url_split.split(text)
        return " ".join(t for t in tokens if t)

    # ββ Dataset class ββββββββββββββββββββββββββββββββββββββββββββ
    class PhishingURLDataset(Dataset):
        # Wraps (url, label) pairs; tokenizes lazily per item in __getitem__.
        def __init__(self, data: List[Tuple[str, int]], tokenizer, max_length: int = 512):
            self.data = data
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self) -> int:
            return len(self.data)

        def __getitem__(self, idx: int):
            url, label = self.data[idx]
            text = f"URL: {tokenize_url(url)}"
            encoding = self.tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            # squeeze(0): tokenizer returns a batch dim of 1; DataLoader re-batches.
            return {
                "input_ids": encoding["input_ids"].squeeze(0),
                "attention_mask": encoding["attention_mask"].squeeze(0),
                "labels": torch.tensor(label, dtype=torch.long),
            }

    # ββ Load model βββββββββββββββββββββββββββββββββββββββββββββββ
    # Primary checkpoint first; a tiny spam model as fallback so the script
    # still runs on constrained hardware / offline mirrors.
    MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
    FALLBACK = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

    print("\nπ€ Loading BERT model...")
    tokenizer = None
    model = None
    for model_id in [MODEL_NAME, FALLBACK]:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id, num_labels=2
            )
            print(f"   β Loaded: {model_id}")
            break
        except Exception as e:
            print(f"   β οΈ {model_id} failed: {e}")
            continue

    if model is None or tokenizer is None:
        print("β Could not load any BERT model. Exiting.")
        sys.exit(1)

    # ββ Prepare data βββββββββββββββββββββββββββββββββββββββββββββ
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"   Device: {device}")

    train_dataset = PhishingURLDataset(train_data, tokenizer)
    val_dataset = PhishingURLDataset(val_data, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)

    model = model.to(device)

    # ββ Optimizer + Scheduler ββββββββββββββββββββββββββββββββββββ
    # NOTE(review): the file header advertises "3 epochs" but EPOCHS is 1 β
    # confirm which is intended before a production run.
    EPOCHS = 1
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
    total_steps = len(train_loader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,  # 10% linear warmup
        num_training_steps=total_steps,
    )

    # ββ Training Loop ββββββββββββββββββββββββββββββββββββββββββββ
    print(f"\nποΈ Training for {EPOCHS} epochs...")
    print(f"   Train batches: {len(train_loader)}")
    print(f"   Val batches:   {len(val_loader)}")

    best_f1 = 0.0
    for epoch in range(1, EPOCHS + 1):
        # Train
        model.train()
        total_loss = 0.0
        # Collected but not reported β kept for parity / future train metrics.
        train_preds = []
        train_labels = []

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            # Passing labels makes the HF model compute the CE loss itself.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()

            # Clip gradients to guard against exploding updates.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            preds = torch.argmax(outputs.logits, dim=1)
            train_preds.extend(preds.cpu().tolist())
            train_labels.extend(labels.cpu().tolist())

            if (batch_idx + 1) % 50 == 0:
                print(f"   Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)

        # Validate
        model.eval()
        val_preds = []
        val_labels = []
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                preds = torch.argmax(outputs.logits, dim=1)
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(labels.cpu().tolist())

        # Binary P/R/F1 on the phishing class; zero_division=0 avoids
        # warnings when an epoch predicts a single class.
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_labels, val_preds, average="binary", zero_division=0
        )

        print(f"\n   π Epoch {epoch}/{EPOCHS}:")
        print(f"      Loss:      {avg_loss:.4f}")
        print(f"      Precision: {precision:.4f}")
        print(f"      Recall:    {recall:.4f}")
        print(f"      F1 Score:  {f1:.4f}")

        # Save best model (checkpoint only when validation F1 improves)
        if f1 > best_f1:
            best_f1 = f1
            BERT_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(str(BERT_WEIGHTS_DIR))
            tokenizer.save_pretrained(str(BERT_WEIGHTS_DIR))
            print(f"   β New best model saved to {BERT_WEIGHTS_DIR}")

    print(f"\nπ― Best F1: {best_f1:.4f}")
    print(f"β Fine-tuning complete. Weights saved to: {BERT_WEIGHTS_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
cnn_inference.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - cnn/cnn_inference.py
|
| 3 |
+
# CNN inference wrapper for Tier 4 visual analysis.
|
| 4 |
+
# Supports: predict, hot-reload, incremental_update.
|
| 5 |
+
# ============================================================
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import io
|
| 10 |
+
import random
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from PIL import Image
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("phishguard.cnn.inference")

# Directory layout: weights live next to this file; the screenshot replay
# buffer lives under the backend's data/ folder.
CNN_DIR = Path(__file__).parent
BACKEND_DIR = CNN_DIR.parent
WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt"
REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class CNNInference:
    """CNN inference wrapper with hot-reload and incremental update.

    Wraps the Tier 4 visual classifier (cnn_model.PhishCNN): lazy loading,
    screenshot-to-probability prediction, hot weight swapping, and an
    online fine-tune step driven by user feedback samples.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        # Weights file for the visual classifier; defaults to cnn_weights.pt
        # next to this module.
        self._weights_path = weights_path or WEIGHTS_PATH
        self._model = None      # lazily created PhishCNN instance
        self._loaded = False

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load CNN model. Returns True when a model object was created."""
        from cnn_model import load_cnn

        path = weights_path or self._weights_path
        # Passing None makes load_cnn build an ImageNet-only baseline model.
        self._model = load_cnn(str(path) if path.exists() else None)
        self._loaded = self._model is not None
        return self._loaded

    def predict(self, screenshot_bytes: bytes) -> float:
        """
        Predict phishing probability from screenshot bytes.
        Returns P_cnn β [0,1].
        """
        if not self._loaded:
            self.load()

        if self._model is None:
            return 0.5  # neutral prior when no model is available

        from cnn_model import preprocess_screenshot

        try:
            tensor = preprocess_screenshot(screenshot_bytes)
            return self._model.predict_proba(tensor)
        except Exception as e:
            # Fail neutral rather than crash the tier pipeline.
            logger.error(f"CNN predict failed: {e}")
            return 0.5

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-reload model with new weights; keeps the old model on failure."""
        from cnn_model import load_cnn

        path = weights_path or self._weights_path
        new_model = load_cnn(str(path))
        if new_model is not None:
            # Swap only after a successful load so in-flight predicts stay valid.
            self._model = new_model
            self._loaded = True
            logger.info(f"CNN hot-reloaded from {path}")
            return True
        return False

    async def incremental_update(
        self,
        tier4_samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 1e-4,
        epochs: int = 3,
    ) -> Optional[float]:
        """
        Incremental update on Tier 4 feedback samples.
        Re-captures screenshots via Playwright, trains on them + replay buffer.
        Returns accuracy_delta or None if no Tier 4 samples.

        tier4_samples: (url, label) pairs; label semantics presumably
        1 = phishing, 0 = benign β TODO confirm against feedback_store.
        """
        if not tier4_samples:
            logger.info("No Tier 4 samples β skipping CNN update")
            return None

        if self._model is None:
            logger.warning("CNN not loaded, cannot update")
            return None

        try:
            import torch.nn as nn
            from torch.optim import AdamW
            from torch.utils.data import DataLoader, TensorDataset
            import torchvision.transforms as T

            device = torch.device("cpu")
            model = self._model.to(device)

            # Mirrors cnn_model.TRANSFORM (ImageNet normalization) so update
            # inputs match inference preprocessing.
            transform = T.Compose([
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            # Try to capture screenshots for the new samples
            tensors = []
            labels = []

            for url, label in tier4_samples:
                try:
                    # Try to capture screenshot; failures just drop the sample.
                    screenshot_bytes = await self._capture_screenshot(url)
                    if screenshot_bytes:
                        img = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")
                        tensor = transform(img)
                        tensors.append(tensor)
                        labels.append(float(label))
                except Exception as e:
                    logger.warning(f"Screenshot capture failed for {url}: {e}")
                    continue

            # Load replay buffer (20% mix) β mitigates catastrophic forgetting.
            buf_path = replay_buffer_path or REPLAY_BUFFER_PATH
            if buf_path.exists():
                try:
                    buf_data = torch.load(buf_path, map_location="cpu", weights_only=False)
                    buf_paths = buf_data.get("paths", [])
                    buf_labels = buf_data.get("labels", [])

                    replay_count = max(1, len(buf_paths) // 5)  # ~20% of buffer
                    indices = random.sample(range(len(buf_paths)), min(replay_count, len(buf_paths)))

                    for idx in indices:
                        try:
                            img = Image.open(buf_paths[idx]).convert("RGB")
                            tensor = transform(img)
                            tensors.append(tensor)
                            labels.append(float(buf_labels[idx]))
                        except Exception:
                            # Skip unreadable buffered screenshots.
                            continue
                except Exception as e:
                    logger.warning(f"CNN replay buffer load failed: {e}")

            if len(tensors) < 5:
                logger.warning(f"Too few CNN samples ({len(tensors)}), skipping update")
                return None

            # Stack and create dataset
            x_data = torch.stack(tensors)
            y_data = torch.tensor(labels, dtype=torch.float)
            dataset = TensorDataset(x_data, y_data)
            loader = DataLoader(dataset, batch_size=8, shuffle=True)

            # Pre-update accuracy β measured on the same data we train on,
            # so the delta is optimistic, not a held-out metric.
            model.eval()
            pre_correct = 0
            with torch.no_grad():
                for bx, by in loader:
                    bx, by = bx.to(device), by.to(device)
                    # NOTE(review): .squeeze() on a final batch of size 1
                    # yields a 0-d tensor; confirm shapes still compare/loss
                    # correctly in that edge case (.squeeze(1) would be safer).
                    out = model(bx).squeeze()
                    preds = (out >= 0.5).float()
                    pre_correct += (preds == by).sum().item()
            pre_acc = pre_correct / len(dataset)

            # Train (head only β backbone stays frozen)
            head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
            optimizer = AdamW(head_params, lr=lr)
            loss_fn = nn.BCELoss()  # model already emits sigmoid probabilities

            model.train()
            for epoch in range(epochs):
                total_loss = 0.0
                for bx, by in loader:
                    bx, by = bx.to(device), by.to(device)
                    optimizer.zero_grad()
                    out = model(bx).squeeze()
                    loss = loss_fn(out, by)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(f"CNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(loader):.4f}")

            # Post-update accuracy
            model.eval()
            post_correct = 0
            with torch.no_grad():
                for bx, by in loader:
                    bx, by = bx.to(device), by.to(device)
                    out = model(bx).squeeze()
                    preds = (out >= 0.5).float()
                    post_correct += (preds == by).sum().item()
            post_acc = post_correct / len(dataset)

            delta = post_acc - pre_acc
            self._model = model

            # Save weights so the next load()/reload() picks up the update.
            torch.save(model.state_dict(), self._weights_path)
            logger.info(f"CNN incremental: {pre_acc:.4f} β {post_acc:.4f} (Ξ={delta:+.4f})")

            return round(delta, 4)

        except Exception as e:
            logger.error(f"CNN incremental update failed: {e}")
            return None

    async def _capture_screenshot(self, url: str) -> Optional[bytes]:
        """Capture a screenshot of a URL using Playwright. Returns PNG bytes or None."""
        try:
            from playwright.async_api import async_playwright

            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=True)
                page = await browser.new_page(viewport={"width": 1280, "height": 800})

                # Block heavy resources to keep capture fast.
                # NOTE(review): blocking images means the screenshot may miss
                # page imagery the CNN was trained on β confirm this is intended.
                await page.route("**/*.{png,jpg,jpeg,gif,svg,mp4,webm,ogg,woff,woff2,ttf,eot}",
                                 lambda route: route.abort())

                await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                screenshot = await page.screenshot(type="png")
                await browser.close()
                return screenshot

        except Exception as e:
            logger.warning(f"Screenshot capture failed: {e}")
            return None

    @property
    def is_loaded(self) -> bool:
        # True once load()/reload() produced a usable model.
        return self._loaded
|
cnn_model.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - cnn/cnn_model.py
|
| 3 |
+
# ResNet50 visual classifier for phishing screenshot detection.
|
| 4 |
+
#
|
| 5 |
+
# Architecture (from spec):
|
| 6 |
+
# Backbone: ResNet50 fully frozen
|
| 7 |
+
# Custom head: Linear(2048β512) β ReLU β Dropout(0.5) β
|
| 8 |
+
# Linear(512β1) β Sigmoid
|
| 9 |
+
# Input: 224Γ224 screenshot tensor
|
| 10 |
+
# Output: P_cnn β [0,1]
|
| 11 |
+
# ============================================================
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import io
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torchvision.models as models
|
| 22 |
+
import torchvision.transforms as T
|
| 23 |
+
from PIL import Image
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger("phishguard.cnn.model")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PhishCNN(nn.Module):
    """
    ResNet50 with frozen backbone and custom 2-layer binary classification head.
    Output: P_cnn β [0,1] via sigmoid.
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()

        # Load pretrained ResNet50 backbone (ImageNet weights when available).
        if pretrained:
            self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        else:
            self.backbone = models.resnet50(weights=None)

        # Freeze entire backbone β only the replacement head below trains.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Replace fc with custom head: 2048 β 512 β 1 β sigmoid
        in_features = self.backbone.fc.in_features  # 2048
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )

        # Ensure custom head is trainable. New modules default to
        # requires_grad=True (the blanket freeze ran before fc was replaced),
        # so this loop is belt-and-braces rather than strictly required.
        for param in self.backbone.fc.parameters():
            param.requires_grad = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        Input:  (batch, 3, 224, 224)
        Output: (batch, 1) probabilities in [0, 1]
        """
        logits = self.backbone(x)
        # Sigmoid here means downstream training must use BCELoss, not
        # BCEWithLogitsLoss.
        return torch.sigmoid(logits)

    def predict_proba(self, x: torch.Tensor) -> float:
        """Return P_cnn β [0,1] β probability of phishing.

        Expects a single-image batch [1, 3, 224, 224]; .item() would raise
        on a multi-image batch.
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(x)
            return output.squeeze().item()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ββ Preprocessing pipeline (matches ImageNet normalization) ββββββββββ
# Inference-time transform: deterministic resize + tensor + normalize.
TRANSFORM = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])

# Training augmentation transforms
# Light augmentation only; same normalization as TRANSFORM so train and
# inference distributions match.
TRAIN_TRANSFORM = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    T.RandomRotation(5),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def preprocess_screenshot(screenshot_bytes: bytes) -> torch.Tensor:
    """Convert raw screenshot bytes into a model-ready tensor [1, 3, 224, 224]."""
    image = Image.open(io.BytesIO(screenshot_bytes))
    rgb_image = image.convert("RGB")
    # unsqueeze adds the batch dimension expected by PhishCNN.forward.
    return TRANSFORM(rgb_image).unsqueeze(0)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def load_cnn(weights_path: Optional[str] = None) -> PhishCNN:
    """Instantiate the CNN, optionally restoring trained weights from disk.

    Falls back to the ImageNet-initialized model (with a warning) when the
    weights file is missing or unreadable. The model is returned in eval mode.
    """
    model = PhishCNN(pretrained=True)

    if not weights_path:
        model.eval()
        return model

    try:
        checkpoint = torch.load(weights_path, map_location="cpu", weights_only=True)
        model.load_state_dict(checkpoint)
        logger.info(f"CNN weights loaded from {weights_path}")
    except Exception as e:
        logger.warning(f"Could not load CNN weights: {e}")
        logger.info("Using ImageNet features only (baseline)")

    model.eval()
    return model
|
config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "ealvaradob/bert-finetuned-phishing",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"BertForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.1,
|
| 7 |
+
"classifier_dropout": null,
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"id2label": {
|
| 13 |
+
"0": "benign",
|
| 14 |
+
"1": "phishing"
|
| 15 |
+
},
|
| 16 |
+
"initializer_range": 0.02,
|
| 17 |
+
"intermediate_size": 4096,
|
| 18 |
+
"label2id": {
|
| 19 |
+
"benign": 0,
|
| 20 |
+
"phishing": 1
|
| 21 |
+
},
|
| 22 |
+
"layer_norm_eps": 1e-12,
|
| 23 |
+
"max_position_embeddings": 512,
|
| 24 |
+
"model_type": "bert",
|
| 25 |
+
"num_attention_heads": 16,
|
| 26 |
+
"num_hidden_layers": 24,
|
| 27 |
+
"pad_token_id": 0,
|
| 28 |
+
"position_embedding_type": "absolute",
|
| 29 |
+
"problem_type": "single_label_classification",
|
| 30 |
+
"torch_dtype": "float32",
|
| 31 |
+
"transformers_version": "4.40.0",
|
| 32 |
+
"type_vocab_size": 2,
|
| 33 |
+
"use_cache": true,
|
| 34 |
+
"vocab_size": 30522
|
| 35 |
+
}
|
content.js
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ============================================================
|
| 2 |
+
// PhishGuard AI - content.js
|
| 3 |
+
// Content script: runs inside every page.
|
| 4 |
+
// Detects phishing signals and injects feedback banner.
|
| 5 |
+
// ============================================================
|
| 6 |
+
|
| 7 |
+
(function() {
|
| 8 |
+
"use strict";
|
| 9 |
+
|
| 10 |
+
// ββ Page Signal Detection ββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
// ββ Page Signal Detection ββββββββββββββββββββββββββββββββββββββ
// Scan the current document for common phishing indicators and return
// { title, snippet, signals } for background.js to score.
function detectPageSignals() {
  const signals = [];
  const title = document.title || "";
  const bodyText = (document.body?.innerText || "").substring(0, 2000).toLowerCase();
  const url = window.location.href.toLowerCase();

  // 1. Password form posting to an external domain.
  for (const form of document.querySelectorAll("form")) {
    const passwordField = form.querySelector('input[type="password"]');
    const action = (form.getAttribute("action") || "").toLowerCase();
    const postsElsewhere =
      action.startsWith("http") && !action.includes(window.location.hostname);
    if (passwordField && postsElsewhere) {
      signals.push("password_form_external_action");
    }
  }

  // 2. Well-known brand named in the title but absent from the hostname.
  const brands = ["paypal","google","apple","microsoft","amazon","netflix",
                  "facebook","instagram","chase","wellsfargo","bankofamerica"];
  const hostname = window.location.hostname.toLowerCase();
  const lowerTitle = title.toLowerCase();
  for (const brand of brands) {
    if (lowerTitle.includes(brand) && !hostname.includes(brand)) {
      signals.push(`brand_mismatch:${brand}`);
    }
  }

  // 3. Urgency language anywhere in the visible text (flagged at most once).
  const urgencyPhrases = [
    "your account has been", "verify immediately", "suspended",
    "unusual activity", "click here now", "act now",
    "confirm your identity", "limited time", "expires soon"
  ];
  if (urgencyPhrases.some(phrase => bodyText.includes(phrase))) {
    signals.push("urgency_language");
  }

  // 4. Hidden or 1x1 iframes.
  for (const frame of document.querySelectorAll("iframe")) {
    const style = window.getComputedStyle(frame);
    const frameWidth = parseInt(style.width) || frame.width;
    const frameHeight = parseInt(style.height) || frame.height;
    const invisible = style.display === "none" || style.visibility === "hidden";
    if (invisible || (frameWidth <= 1 && frameHeight <= 1)) {
      signals.push("hidden_iframe");
    }
  }

  return {
    title,
    snippet: bodyText.substring(0, 500),
    signals,
  };
}
|
| 68 |
+
|
| 69 |
+
// Send signals to background.js: run the heuristics once at injection time
// and forward the result for tier scoring.
try {
  const pageData = detectPageSignals();
  chrome.runtime.sendMessage({
    type: "page_signals",
    url: window.location.href,
    ...pageData,
  });
} catch (e) {
  // Extension context may be invalidated (e.g. the extension was reloaded
  // while this content script is still alive) — fail silently.
}
|
| 80 |
+
|
| 81 |
+
// ── Feedback Banner Injection ──────────────────────────────────
// Listen for messages from background.js to inject the verdict banner.
chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
  if (msg.type === "inject_feedback_banner") {
    injectFeedbackBanner(msg.verdict, msg.confidence, msg.urlHash, msg.tier);
    // Response is sent synchronously; no need to return true here.
    sendResponse({ success: true });
  }
});
|
| 89 |
+
|
| 90 |
+
// Builds and prepends the PhishGuard verdict banner, wiring up the
// Correct / Wrong / Proceed / Close controls.
// verdict: "phishing" or anything else (treated as safe)
// confidence: 0..1; urlHash: opaque id for feedback; tier: detection tier number
function injectFeedbackBanner(verdict, confidence, urlHash, tier) {
  // Don't inject if already present
  if (document.getElementById("phishguard-feedback-banner")) return;

  const isPhishing = verdict === "phishing";
  const confPct = Math.round(confidence * 100);
  const tierText = `Tier ${tier}`;

  const banner = document.createElement("div");
  banner.id = "phishguard-feedback-banner";
  // Max z-index keeps the banner above page content; fixed to viewport top.
  banner.style.cssText = `
    position: fixed; top: 0; left: 0; right: 0; z-index: 2147483647;
    background: ${isPhishing ? "linear-gradient(135deg, #1a0000, #3a0000)" : "linear-gradient(135deg, #001a00, #003a00)"};
    color: white; padding: 10px 20px;
    display: flex; align-items: center; gap: 12px;
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
    font-size: 14px; box-shadow: 0 4px 20px rgba(0,0,0,0.5);
    border-bottom: 2px solid ${isPhishing ? "#ef4444" : "#22c55e"};
  `;

  const icon = isPhishing ? "🛡️" : "✅";
  const statusText = isPhishing ? "PhishGuard flagged this page" : "PhishGuard: Page looks safe";

  banner.innerHTML = `
    <span style="font-size: 18px">${icon}</span>
    <span style="flex: 1">${statusText} · ${confPct}% · ${tierText}</span>
    <button id="pg-correct" style="
      background: rgba(34,197,94,0.2); border: 1px solid #22c55e; color: #22c55e;
      padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 13px;
      font-weight: 600; transition: all 0.2s;
    ">👍 Correct</button>
    <button id="pg-wrong" style="
      background: rgba(239,68,68,0.2); border: 1px solid #ef4444; color: #ef4444;
      padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 13px;
      font-weight: 600; transition: all 0.2s;
    ">👎 Wrong</button>
    ${isPhishing ? `<button id="pg-proceed" style="
      background: transparent; border: 1px solid rgba(255,255,255,0.2); color: #999;
      padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 12px;
      transition: all 0.2s;
    ">Proceed Anyway</button>` : ""}
    <button id="pg-close" style="
      background: none; border: none; color: #666; cursor: pointer;
      font-size: 18px; padding: 0 4px;
    ">×</button>
  `;

  document.body.prepend(banner);
  // Push page content down so the fixed banner doesn't cover it.
  document.body.style.marginTop = (banner.offsetHeight) + "px";

  // Button handlers
  document.getElementById("pg-correct")?.addEventListener("click", () => {
    submitBannerFeedback(urlHash, "correct", banner);
  });

  document.getElementById("pg-wrong")?.addEventListener("click", () => {
    submitBannerFeedback(urlHash, "incorrect", banner);
  });

  // Only rendered for phishing verdicts: whitelist the URL and dismiss.
  document.getElementById("pg-proceed")?.addEventListener("click", () => {
    chrome.runtime.sendMessage({ type: "whitelist_url", url: window.location.href });
    removeBanner(banner);
  });

  document.getElementById("pg-close")?.addEventListener("click", () => {
    removeBanner(banner);
  });
}
|
| 158 |
+
|
| 159 |
+
// Forwards a user verdict ("correct" / "incorrect") on the banner to
// background.js; on success the banner becomes a thank-you note and
// auto-dismisses after 3 seconds.
function submitBannerFeedback(urlHash, feedback, banner) {
  chrome.runtime.sendMessage({
    type: "submit_feedback",
    url_hash: urlHash,
    feedback: feedback,
  }, (response) => {
    // No response or failure leaves the banner untouched.
    if (response?.success) {
      banner.innerHTML = `
        <span style="font-size: 18px">✅</span>
        <span style="flex: 1; color: #22c55e">Thanks! Your feedback helps improve PhishGuard</span>
      `;
      setTimeout(() => removeBanner(banner), 3000);
    }
  });
}
|
| 174 |
+
|
| 175 |
+
// Tears down the PhishGuard banner and undoes the body offset it added.
function removeBanner(banner) {
  const { body } = document;
  body.style.marginTop = "";
  banner.remove();
}
|
| 179 |
+
|
| 180 |
+
})();
|
data_collector.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - data_collector.py
|
| 3 |
+
# Downloads all training data from public HTTP endpoints.
|
| 4 |
+
# No API keys required.
|
| 5 |
+
#
|
| 6 |
+
# Datasets:
|
| 7 |
+
# 1. PhishTank (bz2 JSON β phishing URLs)
|
| 8 |
+
# 2. TRANCO Top-10K (zip CSV β legitimate domains)
|
| 9 |
+
# 3. Kaggle GitHub mirror (CSV β pre-extracted features)
|
| 10 |
+
# ============================================================
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import bz2
|
| 15 |
+
import csv
|
| 16 |
+
import io
|
| 17 |
+
import json
|
| 18 |
+
import zipfile
|
| 19 |
+
import hashlib
|
| 20 |
+
import logging
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Tuple, Optional
|
| 23 |
+
|
| 24 |
+
import requests
|
| 25 |
+
import pandas as pd
|
| 26 |
+
from sklearn.model_selection import train_test_split
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger("phishguard.data_collector")
|
| 29 |
+
|
| 30 |
+
# ββ Data directory ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
DATA_DIR = Path(__file__).parent / "data"
|
| 32 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
|
| 34 |
+
# ββ Public URLs (no API keys) ββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
PHISHTANK_URL = "http://data.phishtank.com/data/online-valid.json.bz2"
|
| 36 |
+
TRANCO_URL = "https://tranco-list.eu/top-1m.csv.zip"
|
| 37 |
+
KAGGLE_PRIMARY = "https://raw.githubusercontent.com/GregaVrbancic/Phishing-Dataset/master/dataset_full.csv"
|
| 38 |
+
KAGGLE_BACKUP = "https://raw.githubusercontent.com/datasets/phishing-websites/master/data.csv"
|
| 39 |
+
|
| 40 |
+
HEADERS = {
|
| 41 |
+
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
| 42 |
+
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
| 43 |
+
"Chrome/120.0.0.0 Safari/537.36"
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def download_phishtank(max_urls: int = 30000) -> List[str]:
    """
    Download phishing URLs from the PhishTank public feed.

    Fetches the bz2 archive -> decompresses -> parses JSON -> keeps only
    records marked both verified and online.

    Args:
        max_urls: Hard cap on the number of URLs returned.

    Returns:
        List of verified phishing URLs (up to max_urls). Falls back to the
        on-disk cache, then to synthetic URLs, when the download fails.
    """
    logger.info("Downloading PhishTank data...")
    phish_cache = DATA_DIR / "phishing_urls.txt"

    # Use cache if recent
    # NOTE(review): no mtime/freshness check is actually performed — any
    # cache file over 1000 bytes with >=100 lines is reused regardless of age.
    if phish_cache.exists() and phish_cache.stat().st_size > 1000:
        urls = phish_cache.read_text().strip().splitlines()
        if len(urls) >= 100:
            logger.info(f"Using cached PhishTank data: {len(urls)} URLs")
            return urls[:max_urls]

    try:
        resp = requests.get(PHISHTANK_URL, headers=HEADERS, timeout=120, stream=True)
        resp.raise_for_status()

        # Decompress bz2 (resp.content still buffers the whole body
        # despite stream=True)
        raw_data = bz2.decompress(resp.content)
        records = json.loads(raw_data)

        # Filter: verified=True AND online (verification_time present)
        urls: List[str] = []
        for record in records:
            if not isinstance(record, dict):
                continue
            url = record.get("url", "").strip()
            verified = record.get("verified", "no")
            online = record.get("online", "no")

            # The feed encodes booleans inconsistently (bool vs "yes"/"true"/1),
            # so accept every truthy spelling.
            is_verified = verified in (True, "yes", "true", "True", "1", 1)
            is_online = online in (True, "yes", "true", "True", "1", 1)

            if url and is_verified and is_online:
                urls.append(url)
                if len(urls) >= max_urls:
                    break

        logger.info(f"PhishTank: {len(urls)} verified+online URLs extracted")

        # Cache to disk
        phish_cache.write_text("\n".join(urls))
        return urls

    except Exception as e:
        logger.warning(f"PhishTank download failed: {e}")
        # Fallback: try to use cached data (even if it failed the >=100 check)
        if phish_cache.exists():
            urls = phish_cache.read_text().strip().splitlines()
            logger.info(f"Using fallback cached data: {len(urls)} URLs")
            return urls[:max_urls]

        # Generate synthetic phishing-like URLs for training
        logger.warning("Generating synthetic phishing URLs as fallback")
        return _generate_synthetic_phishing(500)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _generate_synthetic_phishing(count: int) -> List[str]:
|
| 109 |
+
"""Generate synthetic phishing URLs for training when real data unavailable."""
|
| 110 |
+
import random
|
| 111 |
+
brands = ["paypal", "google", "apple", "microsoft", "amazon", "netflix",
|
| 112 |
+
"facebook", "chase", "wellsfargo", "bankofamerica"]
|
| 113 |
+
tlds = [".xyz", ".tk", ".ml", ".ga", ".cf", ".gq", ".pw", ".top", ".click"]
|
| 114 |
+
keywords = ["login", "verify", "secure", "update", "account", "signin",
|
| 115 |
+
"reset", "confirm", "suspend", "banking", "alert", "password"]
|
| 116 |
+
urls: List[str] = []
|
| 117 |
+
for _ in range(count):
|
| 118 |
+
brand = random.choice(brands)
|
| 119 |
+
tld = random.choice(tlds)
|
| 120 |
+
kw = random.choice(keywords)
|
| 121 |
+
sep = random.choice(["-", ".", ""])
|
| 122 |
+
prefix = random.choice(["http://", "https://"])
|
| 123 |
+
sub = random.choice(["", "www.", "secure.", "login.", "m."])
|
| 124 |
+
urls.append(f"{prefix}{sub}{brand}{sep}{kw}{tld}/{kw}/index.html")
|
| 125 |
+
return urls
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def download_tranco(n: int = 10000) -> List[str]:
    """
    Download the TRANCO Top-1M list and return the top-N domains as
    https:// URLs.

    Fetches the zip -> extracts the first CSV member -> takes column 2
    (domain) from the first N rows.

    Args:
        n: Number of domains to return.

    Returns:
        List of ``https://<domain>`` strings; falls back to the on-disk
        cache, then to a synthetic list, on failure.
    """
    logger.info(f"Downloading TRANCO top-{n} domains...")
    legit_cache = DATA_DIR / "legitimate_urls.txt"

    # Use cache if present (no freshness check — any usable file wins)
    if legit_cache.exists() and legit_cache.stat().st_size > 1000:
        urls = legit_cache.read_text().strip().splitlines()
        if len(urls) >= min(n, 100):
            logger.info(f"Using cached TRANCO data: {len(urls)} domains")
            return urls[:n]

    try:
        resp = requests.get(TRANCO_URL, headers=HEADERS, timeout=60)
        resp.raise_for_status()

        # Extract CSV from zip (archive is expected to hold one member)
        with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
            csv_name = zf.namelist()[0]
            csv_data = zf.read(csv_name).decode("utf-8")

        # Parse: format is "rank,domain" per line
        urls: List[str] = []
        for line in csv_data.strip().splitlines():
            parts = line.split(",")
            if len(parts) >= 2:
                domain = parts[1].strip()
                if domain:
                    urls.append(f"https://{domain}")
                    if len(urls) >= n:
                        break

        logger.info(f"TRANCO: {len(urls)} legitimate domains extracted")

        # Cache to disk
        legit_cache.write_text("\n".join(urls))
        return urls

    except Exception as e:
        logger.warning(f"TRANCO download failed: {e}")
        # Fallback: use cached data or generate synthetic
        if legit_cache.exists():
            urls = legit_cache.read_text().strip().splitlines()
            return urls[:n]

        logger.warning("Generating synthetic legitimate URLs as fallback")
        return _generate_synthetic_legitimate(n)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _generate_synthetic_legitimate(count: int) -> List[str]:
|
| 182 |
+
"""Generate legitimate-looking URLs as fallback."""
|
| 183 |
+
top_domains = [
|
| 184 |
+
"google.com", "youtube.com", "facebook.com", "amazon.com",
|
| 185 |
+
"wikipedia.org", "twitter.com", "instagram.com", "linkedin.com",
|
| 186 |
+
"microsoft.com", "apple.com", "github.com", "stackoverflow.com",
|
| 187 |
+
"reddit.com", "netflix.com", "paypal.com", "yahoo.com", "bing.com",
|
| 188 |
+
"adobe.com", "dropbox.com", "zoom.us", "slack.com", "spotify.com",
|
| 189 |
+
"twitch.tv", "ebay.com", "walmart.com", "target.com", "cnn.com",
|
| 190 |
+
"bbc.com", "nytimes.com", "medium.com",
|
| 191 |
+
]
|
| 192 |
+
urls = [f"https://{d}" for d in top_domains]
|
| 193 |
+
# Pad with numbered subpages
|
| 194 |
+
while len(urls) < count:
|
| 195 |
+
d = top_domains[len(urls) % len(top_domains)]
|
| 196 |
+
urls.append(f"https://{d}/page/{len(urls)}")
|
| 197 |
+
return urls[:count]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def download_kaggle_mirror() -> pd.DataFrame:
    """
    Download pre-extracted URL features from a Kaggle GitHub mirror.
    Tries KAGGLE_PRIMARY first, then KAGGLE_BACKUP.

    Returns:
        DataFrame with feature columns plus a normalized binary
        ``CLASS_LABEL`` column (1 = phishing, 0 = legitimate), or an empty
        DataFrame when every mirror fails.
    """
    logger.info("Downloading Kaggle URL features dataset...")
    kaggle_cache = DATA_DIR / "kaggle_features.csv"

    # Cached copy already has CLASS_LABEL normalized (written below).
    if kaggle_cache.exists() and kaggle_cache.stat().st_size > 1000:
        logger.info("Using cached Kaggle features")
        return pd.read_csv(kaggle_cache)

    for url in [KAGGLE_PRIMARY, KAGGLE_BACKUP]:
        try:
            resp = requests.get(url, headers=HEADERS, timeout=60)
            resp.raise_for_status()
            df = pd.read_csv(io.StringIO(resp.text))

            # Standardize label column name
            label_candidates = ["CLASS_LABEL", "class_label", "Result", "result", "label"]
            for col in label_candidates:
                if col in df.columns:
                    df = df.rename(columns={col: "CLASS_LABEL"})
                    break

            if "CLASS_LABEL" not in df.columns:
                # Try last column (datasets conventionally put the label last)
                df = df.rename(columns={df.columns[-1]: "CLASS_LABEL"})

            # Normalize labels to 0/1
            if df["CLASS_LABEL"].dtype == object:
                # String labels; anything unrecognized becomes 0 (legitimate)
                df["CLASS_LABEL"] = df["CLASS_LABEL"].map(
                    {"legitimate": 0, "phishing": 1, "safe": 0}
                ).fillna(0).astype(int)
            else:
                # Handle -1 as legitimate (common in some datasets)
                df["CLASS_LABEL"] = df["CLASS_LABEL"].apply(
                    lambda x: 0 if x <= 0 else 1
                )

            # Cache
            df.to_csv(kaggle_cache, index=False)
            logger.info(f"Kaggle features: {len(df)} rows, {len(df.columns)} columns")
            return df

        except Exception as e:
            logger.warning(f"Kaggle mirror {url} failed: {e}")
            continue

    logger.error("All Kaggle mirrors failed")
    return pd.DataFrame()
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def merge_datasets(
    phish_urls: List[str],
    legit_urls: List[str],
    test_size: float = 0.15,
    val_size: float = 0.15,
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Merge phishing + legitimate URLs and return a stratified split
    (70/15/15 with the default fractions).

    Args:
        phish_urls: URLs labeled phishing (label 1).
        legit_urls: URLs labeled legitimate (label 0); any URL that also
            appears in phish_urls is kept only on the phishing side.
        test_size: Fraction of the whole dataset held out for test.
        val_size: Fraction of the whole dataset held out for validation.

    Returns:
        (train, val, test) where each is List[(url, label)].
    """
    # Deduplicate; a URL present in both lists counts as phishing only.
    phish_set = set(phish_urls)
    legit_set = set(legit_urls) - phish_set

    # BUGFIX: iterate sorted lists, not the raw sets. Set iteration order
    # depends on string-hash randomization, so random_state=42 alone did
    # NOT make this split reproducible across interpreter runs.
    all_data = (
        [(url, 1) for url in sorted(phish_set)]
        + [(url, 0) for url in sorted(legit_set)]
    )
    urls = [d[0] for d in all_data]
    labels = [d[1] for d in all_data]

    # First split: train+val vs test
    train_val_urls, test_urls, train_val_labels, test_labels = train_test_split(
        urls, labels,
        test_size=test_size,
        stratify=labels,
        random_state=42,
    )

    # Second split: train vs val. val_size is a fraction of the WHOLE
    # dataset, so rescale it to the remaining (1 - test_size) portion.
    relative_val = val_size / (1 - test_size)
    train_urls, val_urls, train_labels, val_labels = train_test_split(
        train_val_urls, train_val_labels,
        test_size=relative_val,
        stratify=train_val_labels,
        random_state=42,
    )

    train = list(zip(train_urls, train_labels))
    val = list(zip(val_urls, val_labels))
    test = list(zip(test_urls, test_labels))

    logger.info(f"Dataset split: train={len(train)}, val={len(val)}, test={len(test)}")
    return train, val, test
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def save_url_lists(
    phish_urls: List[str],
    legit_urls: List[str],
    phish_path: Optional[Path] = None,
    legit_path: Optional[Path] = None,
) -> None:
    """Write the phishing and legitimate URL lists to newline-separated files.

    Either path argument may be omitted, in which case the corresponding
    default file inside DATA_DIR is used.
    """
    if phish_path is None:
        phish_path = DATA_DIR / "phishing_urls.txt"
    if legit_path is None:
        legit_path = DATA_DIR / "legitimate_urls.txt"

    phish_path.write_text("\n".join(phish_urls))
    legit_path.write_text("\n".join(legit_urls))
    logger.info(f"Saved {len(phish_urls)} phishing URLs to {phish_path}")
    logger.info(f"Saved {len(legit_urls)} legitimate URLs to {legit_path}")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def url_hash(url: str) -> str:
    """Return the hex SHA-256 digest of *url* (used for dedup and privacy)."""
    digest = hashlib.sha256()
    digest.update(url.encode("utf-8"))
    return digest.hexdigest()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# ββ Entry point ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 322 |
+
def main() -> None:
    """CLI entry point: download all datasets, cache them, print a summary."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)-7s | %(message)s",
    )

    print("=" * 60)
    print("PhishGuard AI — Data Collection")
    print("=" * 60)

    # 1. PhishTank — verified phishing URLs (network, cached on disk)
    phish_urls = download_phishtank()
    print(f"\n✅ PhishTank: {len(phish_urls)} phishing URLs")

    # 2. TRANCO — top-ranked legitimate domains
    legit_urls = download_tranco(n=10000)
    print(f"✅ TRANCO: {len(legit_urls)} legitimate URLs")

    # 3. Kaggle features — optional pre-extracted feature table
    kaggle_df = download_kaggle_mirror()
    if not kaggle_df.empty:
        phish_count = (kaggle_df["CLASS_LABEL"] == 1).sum()
        legit_count = (kaggle_df["CLASS_LABEL"] == 0).sum()
        print(f"✅ Kaggle: {len(kaggle_df)} rows ({phish_count} phishing, {legit_count} legit)")
    else:
        print("⚠️ Kaggle: download failed (will use PhishTank + TRANCO only)")

    # 4. Save URL lists
    save_url_lists(phish_urls, legit_urls)

    # 5. Merge and split
    train, val, test = merge_datasets(phish_urls, legit_urls)
    print(f"\n📊 Dataset splits:")
    print(f"  Train: {len(train)} ({sum(1 for _,l in train if l==1)} phish / {sum(1 for _,l in train if l==0)} legit)")
    print(f"  Val: {len(val)} ({sum(1 for _,l in val if l==1)} phish / {sum(1 for _,l in val if l==0)} legit)")
    print(f"  Test: {len(test)} ({sum(1 for _,l in test if l==1)} phish / {sum(1 for _,l in test if l==0)} legit)")

    print(f"\n✅ All data saved to {DATA_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|
domain_graph_builder.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - gnn/domain_graph_builder.py
|
| 3 |
+
# Builds graph representations for GNN inference + training.
|
| 4 |
+
#
|
| 5 |
+
# Node features (12-dim per URL):
|
| 6 |
+
# [url_len_norm, domain_len_norm, subdomain_count_norm,
|
| 7 |
+
# shannon_entropy_norm, digit_ratio, hyphen_count_norm,
|
| 8 |
+
# phishing_keyword_hits_norm, suspicious_tld_binary,
|
| 9 |
+
# ip_as_hostname_binary, has_https_binary,
|
| 10 |
+
# path_depth_norm, query_string_len_norm]
|
| 11 |
+
#
|
| 12 |
+
# Edges: shared suspicious TLD + shared IP (async DNS)
|
| 13 |
+
# ============================================================
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
import math
|
| 19 |
+
import asyncio
|
| 20 |
+
import logging
|
| 21 |
+
import socket
|
| 22 |
+
from typing import Dict, List, Optional, Tuple
|
| 23 |
+
from urllib.parse import urlparse
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger("phishguard.gnn.graph_builder")
|
| 28 |
+
|
| 29 |
+
# ββ Constants ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
SUSPICIOUS_TLDS = frozenset({
|
| 31 |
+
".xyz", ".tk", ".ml", ".ga", ".cf",
|
| 32 |
+
".gq", ".pw", ".top", ".click",
|
| 33 |
+
})
|
| 34 |
+
|
| 35 |
+
PHISHING_KEYWORDS = frozenset({
|
| 36 |
+
"login", "verify", "secure", "update", "account",
|
| 37 |
+
"banking", "signin", "reset", "confirm", "suspend",
|
| 38 |
+
"webscr", "cmd", "payment", "alert",
|
| 39 |
+
})
|
| 40 |
+
|
| 41 |
+
_re_ip = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DomainGraphBuilder:
    """
    Builds PyTorch Geometric Data objects from URL lists.
    Each URL becomes a node with 12-dim feature vector.
    Edges are created from shared IP addresses and shared TLDs.
    """

    def __init__(self) -> None:
        # Bind the module-level compiled IPv4 regex on the instance so
        # feature extraction avoids a global lookup per call.
        self._re_ip = _re_ip
|
| 53 |
+
|
| 54 |
+
    def extract_node_features(self, url: str) -> np.ndarray:
        """
        Extract a 12-dim feature vector from a URL.

        Feature ORDER matters: it must match what the GNN was trained on,
        so the positions below are part of the model contract.
        All values are clamped/normalized into [0, 1].

        Returns:
            np.ndarray of shape (12,), dtype float32; all zeros when the
            URL cannot be parsed at all.
        """
        try:
            # Schemeless inputs get an http:// prefix so urlparse yields a hostname.
            parsed = urlparse(url if "://" in url else f"http://{url}")
        except Exception:
            return np.zeros(12, dtype=np.float32)

        hostname: str = (parsed.hostname or "").lower()
        path: str = parsed.path or ""
        query: str = parsed.query or ""
        scheme: str = parsed.scheme or ""

        # 1. url_len_norm (normalized by 500)
        url_len_norm = min(len(url) / 500.0, 1.0)

        # 2. domain_len_norm (normalized by 100)
        domain_len_norm = min(len(hostname) / 100.0, 1.0)

        # 3. subdomain_count_norm (labels beyond registrable domain + TLD)
        parts = hostname.split(".")
        subdomain_count = max(0, len(parts) - 2)
        subdomain_count_norm = min(subdomain_count / 10.0, 1.0)

        # 4. shannon_entropy_norm (normalized by 5.0 bits/char)
        entropy = self._shannon_entropy(hostname)
        shannon_entropy_norm = min(entropy / 5.0, 1.0)

        # 5. digit_ratio (fraction of hostname chars that are digits)
        digit_ratio = 0.0
        if hostname:
            digits = sum(1 for c in hostname if c.isdigit())
            digit_ratio = digits / len(hostname)

        # 6. hyphen_count_norm
        hyphen_count = hostname.count("-")
        hyphen_count_norm = min(hyphen_count / 10.0, 1.0)

        # 7. phishing_keyword_hits_norm (matched against the whole URL)
        url_lower = url.lower()
        keyword_hits = sum(1 for kw in PHISHING_KEYWORDS if kw in url_lower)
        phishing_keyword_hits_norm = min(keyword_hits / 5.0, 1.0)

        # 8. suspicious_tld_binary
        suspicious_tld_binary = 0.0
        for tld in SUSPICIOUS_TLDS:
            if hostname.endswith(tld):
                suspicious_tld_binary = 1.0
                break

        # 9. ip_as_hostname_binary (dotted-quad literal as hostname)
        ip_as_hostname_binary = 1.0 if self._re_ip.match(hostname) else 0.0

        # 10. has_https_binary
        has_https_binary = 1.0 if scheme == "https" else 0.0

        # 11. path_depth_norm (non-empty path segments)
        path_segments = [s for s in path.split("/") if s]
        path_depth_norm = min(len(path_segments) / 10.0, 1.0)

        # 12. query_string_len_norm
        query_string_len_norm = min(len(query) / 500.0, 1.0)

        features = np.array([
            url_len_norm,
            domain_len_norm,
            subdomain_count_norm,
            shannon_entropy_norm,
            digit_ratio,
            hyphen_count_norm,
            phishing_keyword_hits_norm,
            suspicious_tld_binary,
            ip_as_hostname_binary,
            has_https_binary,
            path_depth_norm,
            query_string_len_norm,
        ], dtype=np.float32)

        return features
|
| 136 |
+
|
| 137 |
+
def _shannon_entropy(self, s: str) -> float:
|
| 138 |
+
"""Compute Shannon entropy of a string."""
|
| 139 |
+
if not s:
|
| 140 |
+
return 0.0
|
| 141 |
+
length = len(s)
|
| 142 |
+
freq: Dict[str, int] = {}
|
| 143 |
+
for c in s:
|
| 144 |
+
freq[c] = freq.get(c, 0) + 1
|
| 145 |
+
return -sum(
|
| 146 |
+
(count / length) * math.log2(count / length)
|
| 147 |
+
for count in freq.values()
|
| 148 |
+
if count > 0
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
async def _resolve_ips(self, domains: List[str]) -> Dict[str, str]:
|
| 152 |
+
"""
|
| 153 |
+
Async DNS resolution for a list of domains.
|
| 154 |
+
Returns dict mapping domain β IP address.
|
| 155 |
+
"""
|
| 156 |
+
results: Dict[str, str] = {}
|
| 157 |
+
loop = asyncio.get_event_loop()
|
| 158 |
+
|
| 159 |
+
async def resolve_one(domain: str) -> Tuple[str, str]:
|
| 160 |
+
try:
|
| 161 |
+
ip = await asyncio.wait_for(
|
| 162 |
+
loop.run_in_executor(None, socket.gethostbyname, domain),
|
| 163 |
+
timeout=2.0,
|
| 164 |
+
)
|
| 165 |
+
return domain, ip
|
| 166 |
+
except Exception:
|
| 167 |
+
return domain, ""
|
| 168 |
+
|
| 169 |
+
tasks = [resolve_one(d) for d in domains]
|
| 170 |
+
resolved = await asyncio.gather(*tasks, return_exceptions=True)
|
| 171 |
+
for item in resolved:
|
| 172 |
+
if isinstance(item, tuple):
|
| 173 |
+
domain, ip = item
|
| 174 |
+
if ip:
|
| 175 |
+
results[domain] = ip
|
| 176 |
+
return results
|
| 177 |
+
|
| 178 |
+
def _add_shared_ip_edges(
|
| 179 |
+
self, domains: List[str], ips: Dict[str, str]
|
| 180 |
+
) -> List[Tuple[int, int]]:
|
| 181 |
+
"""
|
| 182 |
+
Create edges between nodes that share the same IP address.
|
| 183 |
+
Returns list of (src, dst) index pairs.
|
| 184 |
+
"""
|
| 185 |
+
edges: List[Tuple[int, int]] = []
|
| 186 |
+
# Group domain indices by IP
|
| 187 |
+
ip_to_indices: Dict[str, List[int]] = {}
|
| 188 |
+
for idx, domain in enumerate(domains):
|
| 189 |
+
ip = ips.get(domain, "")
|
| 190 |
+
if ip:
|
| 191 |
+
ip_to_indices.setdefault(ip, []).append(idx)
|
| 192 |
+
|
| 193 |
+
# Create edges between all nodes sharing an IP
|
| 194 |
+
for ip, indices in ip_to_indices.items():
|
| 195 |
+
for i in range(len(indices)):
|
| 196 |
+
for j in range(i + 1, len(indices)):
|
| 197 |
+
edges.append((indices[i], indices[j]))
|
| 198 |
+
edges.append((indices[j], indices[i])) # bidirectional
|
| 199 |
+
|
| 200 |
+
return edges
|
| 201 |
+
|
| 202 |
+
def _add_shared_tld_edges(self, domains: List[str]) -> List[Tuple[int, int]]:
    """
    Link every pair of domains that end in the same suspicious TLD.

    Each domain is assigned to at most one TLD bucket (the first match in
    SUSPICIOUS_TLDS); all nodes in a bucket are connected bidirectionally.
    """
    buckets: Dict[str, List[int]] = {}
    for node_idx, name in enumerate(domains):
        for suffix in SUSPICIOUS_TLDS:
            if name.endswith(suffix):
                buckets.setdefault(suffix, []).append(node_idx)
                break  # count each domain once, for its first matching TLD

    result: List[Tuple[int, int]] = []
    for members in buckets.values():
        for a in range(len(members)):
            for b in range(a + 1, len(members)):
                result.append((members[a], members[b]))
                result.append((members[b], members[a]))

    return result
|
| 222 |
+
|
| 223 |
+
def build_graph(self, urls: List[str], resolve_dns: bool = False) -> dict:
    """
    Build a graph dict from a list of URLs.

    Each URL becomes one node with a 12-dimensional feature vector; edges
    connect nodes that share a suspicious TLD and (optionally, when
    ``resolve_dns`` is True and we are not already inside a running event
    loop) nodes whose domains resolve to the same IP.

    Returns dict with:
        - features: np.ndarray of shape (N, 12)
        - edges: List of (src, dst) pairs
        - node_count: int
        - edge_count: int
        - domains: List[str]

    An empty input returns a single all-zero feature row with
    node_count 0 so downstream consumers always get a valid array.
    """
    if not urls:
        return {
            "features": np.zeros((1, 12), dtype=np.float32),
            "edges": [],
            "node_count": 0,
            "edge_count": 0,
            "domains": [],
        }

    # Extract features for each URL
    features = np.array(
        [self.extract_node_features(url) for url in urls],
        dtype=np.float32,
    )

    # Extract domains; scheme-less inputs get a dummy "http://" prefix so
    # urlparse treats the leading token as the hostname.
    domains: List[str] = []
    for url in urls:
        try:
            parsed = urlparse(url if "://" in url else f"http://{url}")
            domains.append((parsed.hostname or "").lower())
        except Exception:
            domains.append("")

    # Build edges from shared TLDs (synchronous, fast)
    edges = self._add_shared_tld_edges(domains)

    # Optionally resolve DNS for shared IP edges
    if resolve_dns and len(domains) > 1:
        try:
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # Already in async context: run_until_complete would raise,
                # so shared-IP edges are silently skipped here.
                pass
            else:
                ips = loop.run_until_complete(self._resolve_ips(domains))
                edges.extend(self._add_shared_ip_edges(domains, ips))
        except RuntimeError:
            pass  # Cannot resolve in this context

    return {
        "features": features,
        "edges": edges,
        "node_count": len(urls),
        "edge_count": len(edges),
        "domains": domains,
    }
|
| 281 |
+
|
| 282 |
+
def build_single_node_graph(self, url: str) -> dict:
    """
    Wrap a single URL into a one-node, zero-edge graph.

    Serves the MLP fallback path taken when a graph would have fewer
    than 2 nodes.
    """
    node_features = self.extract_node_features(url).reshape(1, -1)
    graph = {
        "features": node_features,
        "edges": [],
        "node_count": 1,
        "edge_count": 0,
        "domains": [url],
    }
    return graph
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# -- Legacy compatibility wrapper --------------------------------------
# Module-level singleton so legacy callers share one builder instance.
_builder = DomainGraphBuilder()


def build_domain_graph(urls: List[str]) -> dict:
    """Legacy wrapper for backward compatibility.

    Delegates to the shared DomainGraphBuilder with DNS resolution
    disabled (the historical default behavior).
    """
    return _builder.build_graph(urls)
|
email_analyzer.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - email_analyzer.py
|
| 3 |
+
# Analyzes raw emails for phishing indicators.
|
| 4 |
+
# Checks: sender authentication (SPF/DKIM/DMARC),
|
| 5 |
+
# brand spoofing, urgency language, and embedded links.
|
| 6 |
+
#
|
| 7 |
+
# Reuses BERT model from bert_analyzer to avoid duplicate loading.
|
| 8 |
+
# ============================================================
|
| 9 |
+
|
| 10 |
+
import email
|
| 11 |
+
import re
|
| 12 |
+
from email import policy
|
| 13 |
+
from email.parser import BytesParser, Parser
|
| 14 |
+
|
| 15 |
+
# Reuse the NLP analyzer from bert_analyzer
|
| 16 |
+
from bert_analyzer import analyze_text as bert_analyze_text, _ensure_bert_loaded
|
| 17 |
+
import bert_analyzer
|
| 18 |
+
|
| 19 |
+
print("[PhishGuard] Email analyzer initialized (reusing shared NLP)")
|
| 20 |
+
|
| 21 |
+
# Regex patterns for urgency/pressure language typical of phishing lures.
# NOTE: several patterns contain nested groups, so re.findall on them
# yields tuples rather than plain strings (callers stringify via str()).
URGENCY_PATTERNS = [
    r'(act now|immediate action|urgent|verify immediately|account suspended)',
    r'(click here to (verify|confirm|update|restore))',
    r'(your account (will be|has been) (suspended|closed|deactivated))',
    r'(limited time|expires in \d+ hours?)',
    r'(unusual (sign-in|login|activity) detected)',
    r'(confirm your (identity|password|email|account))',
    r'(we noticed (suspicious|unusual|unauthorized))',
]

# Brand names commonly impersonated in phishing; matched case-insensitively
# against subject/body/sender by check_brand_spoofing.
BRAND_SPOOFS = [
    'paypal','amazon','apple','microsoft','google','netflix',
    'facebook','instagram','linkedin','twitter','chase','wellsfargo',
    'bankofamerica','citibank','irs','fedex','ups','dhl',
    'dropbox','docusign','zoom','office365','hdfc','icici','sbi'
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def parse_email_msg(raw):
    """Parse a raw RFC 822 message (bytes or str) into an email.message object.

    Uses the modern ``policy.default`` so parts expose ``get_content()``.
    """
    if isinstance(raw, bytes):
        byte_parser = BytesParser(policy=policy.default)
        return byte_parser.parsebytes(raw)
    text_parser = Parser(policy=policy.default)
    return text_parser.parsestr(raw)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def extract_urls(text: str) -> list:
    """Extract all unique HTTP/HTTPS URLs from *text* (order unspecified)."""
    pattern = r'https?://[^\s<>"\'\\ ]+'
    unique = set(re.findall(pattern, text))
    return list(unique)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_body(msg) -> str:
    """
    Extract the plain-text body of an email message.

    Prefers text/plain parts; when none have been found yet, falls back to
    text/html with tags crudely stripped. For non-multipart messages the
    single payload is used directly. Undecodable parts are skipped
    (best-effort), and all collected parts are joined with spaces.
    """
    parts = []
    if msg.is_multipart():
        for part in msg.walk():
            ct = part.get_content_type()
            if ct == 'text/plain':
                # Bare `except:` would also trap KeyboardInterrupt/SystemExit;
                # Exception is the right scope for decode failures.
                try:
                    parts.append(part.get_content())
                except Exception:
                    pass
            elif ct == 'text/html' and not parts:
                # Only fall back to HTML while no plain text is collected.
                try:
                    parts.append(re.sub(r'<[^>]+>', ' ', part.get_content()))
                except Exception:
                    pass
    else:
        try:
            parts.append(msg.get_content())
        except Exception:
            pass
    return ' '.join(parts)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def check_sender_auth(msg) -> dict:
    """
    Score sender-authentication signals from message headers.

    Inspects SPF/DKIM/DMARC results, From vs Return-Path domain mismatch,
    and whether the sender uses a free mail provider. Each failing signal
    adds a fixed penalty; the total is capped at 100.
    """
    auth_header = msg.get('Authentication-Results', '').lower()
    spf_header = msg.get('Received-SPF', '').lower()
    spf_ok = 'spf=pass' in auth_header or 'pass' in spf_header
    dkim_ok = 'dkim=pass' in auth_header
    dmarc_ok = 'dmarc=pass' in auth_header

    # Compare the domain portion of From against Return-Path.
    from_match = re.search(r'@([\w.-]+)', msg.get('From', ''))
    ret_match = re.search(r'@([\w.-]+)', msg.get('Return-Path', ''))
    domains_differ = bool(
        from_match and ret_match and from_match.group(1) != ret_match.group(1)
    )

    free_providers = {'gmail.com', 'yahoo.com', 'hotmail.com',
                      'outlook.com', 'protonmail.com'}
    on_free_provider = (from_match.group(1).lower() in free_providers) if from_match else False

    penalty = 0
    if not spf_ok:
        penalty += 25
    if not dkim_ok:
        penalty += 20
    if not dmarc_ok:
        penalty += 15
    if domains_differ:
        penalty += 30
    if on_free_provider:
        penalty += 10

    return {
        "spf_pass": spf_ok, "dkim_pass": dkim_ok,
        "dmarc_pass": dmarc_ok, "domain_mismatch": domains_differ,
        "using_free_email": on_free_provider,
        "auth_risk_score": min(penalty, 100)
    }
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def check_brand_spoofing(subject: str, body: str, sender: str) -> dict:
    """
    Flag well-known brand names that appear anywhere in the email but are
    absent from the sender's domain -- a classic impersonation signal.
    """
    haystack = ' '.join((subject, body, sender)).lower()
    dom_match = re.search(r'@([\w.-]+)', sender)
    sender_domain = dom_match.group(1).lower() if dom_match else ''
    impersonated = [
        brand for brand in BRAND_SPOOFS
        if brand in haystack and brand not in sender_domain
    ]
    return {
        "brand_spoof_detected": bool(impersonated),
        "spoofed_brands": impersonated
    }
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def check_urgency(text: str) -> dict:
    """
    Detect pressure/urgency phrases common in phishing lures.

    Scoring: 15 points per regex hit, capped at 60. Patterns with nested
    groups make re.findall return tuples; those are stringified verbatim.
    """
    lowered = text.lower()
    hits = []
    for pattern in URGENCY_PATTERNS:
        hits.extend(re.findall(pattern, lowered))
    return {
        "urgency_detected": bool(hits),
        "urgency_matches": [str(h) for h in hits[:5]],
        "urgency_score": min(len(hits) * 15, 60)
    }
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def bert_score(text: str) -> float:
    """
    Return the phishing probability of *text* via the shared NLP model.

    Falls back to bert_analyzer's keyword analysis when the transformer is
    unavailable, and to a neutral 0.3 on any failure. Empty/whitespace-only
    input scores a low-risk 0.1.
    """
    if not text.strip():
        return 0.1
    try:
        _ensure_bert_loaded()
        if bert_analyzer._use_bert and bert_analyzer._classifier is not None:
            # Transformer input is truncated to the model's 512-char budget.
            result = bert_analyzer._classifier(text[:512])[0]
            label = result['label'].upper()
            score = result['score']
            # The model may tag phishing as "SPAM" or a generic "LABEL_1".
            return score if ('SPAM' in label or label == 'LABEL_1') else 1 - score
        else:
            # Use keyword analysis from bert_analyzer
            result = bert_analyze_text("", "", text)
            return result.get("bert_phishing_prob", 0.3)
    except Exception:
        # Bare `except:` previously trapped KeyboardInterrupt/SystemExit too;
        # scoring failures must never break email analysis, so return neutral.
        return 0.3
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def analyze_email(raw, return_urls: bool = True) -> dict:
    """
    Full phishing analysis of a raw email.
    Pass raw bytes or a string of the full email.

    Combines: BERT NLP score + sender auth + brand spoofing + urgency detection.

    Score blend (out of 100, then normalized to [0, 1]):
        - NLP probability        x 40
        - auth risk (0-100)      x 0.30
        - urgency score (0-60)   x 0.20
        - brand spoof (0 or 30)  x 0.10
    Emails scoring above 0.60 are flagged as phishing.
    """
    msg = parse_email_msg(raw)
    subject = msg.get('Subject', '')
    sender = msg.get('From', '')
    body = get_body(msg)
    urls = extract_urls(body)

    auth = check_sender_auth(msg)
    brand = check_brand_spoofing(subject, body, sender)
    urgency = check_urgency(subject + ' ' + body)
    # Only the subject plus the first 400 body chars feed the NLP model.
    bert_p = bert_score(subject + '. ' + body[:400])

    raw_score = (bert_p * 40 +
                 auth['auth_risk_score'] * 0.30 +
                 urgency['urgency_score'] * 0.20 +
                 (30 if brand['brand_spoof_detected'] else 0) * 0.10)
    final = min(raw_score / 100, 1.0)

    result = {
        "is_phishing": final > 0.60,
        "phishing_probability": round(final, 4),
        "subject": subject,
        "sender": sender,
        "auth_analysis": auth,
        "brand_analysis": brand,
        "urgency_analysis": urgency,
        "bert_score": round(bert_p, 4),
        "extracted_url_count": len(urls),
    }
    if return_urls:
        # Cap payload size: at most 20 URLs are echoed back.
        result["extracted_urls"] = urls[:20]
    return result
|
feedback_store.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - feedback_store.py
|
| 3 |
+
# Thread-safe feedback storage, retraining trigger, analytics.
|
| 4 |
+
#
|
| 5 |
+
# Storage: feedback_data.jsonl (append-only, one JSON per line)
|
| 6 |
+
# Lock: asyncio.Lock prevents concurrent writes & double-retrain
|
| 7 |
+
# ============================================================
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import time
|
| 14 |
+
import asyncio
|
| 15 |
+
import shutil
|
| 16 |
+
import logging
|
| 17 |
+
from datetime import datetime, timezone
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger("phishguard.feedback")
|
| 21 |
+
|
| 22 |
+
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 23 |
+
FEEDBACK_FILE = os.path.join(_BASE_DIR, "feedback_data.jsonl")
|
| 24 |
+
STATE_FILE = os.path.join(_BASE_DIR, "retrain_state.json")
|
| 25 |
+
|
| 26 |
+
# -- Async lock for thread-safe writes ----------------------------------------
# Serializes appends to the JSONL file within the event loop.
_write_lock = asyncio.Lock()

# -- Retrain state (persisted to retrain_state.json) --------------------------
# In-memory mirror of the on-disk state; mutated by append_feedback() and
# mark_retrain_complete(), persisted via _save_state().
_retrain_state = {
    "model_version": 1,
    "total_feedback": 0,
    "unprocessed_count": 0,
    "phishing_corrections": 0,
    "safe_corrections": 0,
    "last_retrain": None,  # ISO 8601 timestamp
    "retrain_history": [],  # [{ts, samples, accuracy, version}]
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _load_state():
    """Restore persisted retrain state from STATE_FILE into _retrain_state."""
    global _retrain_state
    if not os.path.exists(STATE_FILE):
        return
    try:
        with open(STATE_FILE, "r") as fh:
            saved = json.load(fh)
        _retrain_state.update(saved)
        logger.info(f"[FeedbackStore] State loaded | version={_retrain_state['model_version']} | total={_retrain_state['total_feedback']}")
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not load state: {e}")
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _save_state():
    """Atomically persist the retrain state (write temp file, then replace)."""
    tmp_path = STATE_FILE + ".tmp"
    try:
        with open(tmp_path, "w") as fh:
            json.dump(_retrain_state, fh, indent=2, default=str)
        os.replace(tmp_path, STATE_FILE)
    except Exception as e:
        logger.warning(f"[FeedbackStore] Could not save state: {e}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Load state on module import so feedback counters survive process restarts.
_load_state()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
# FEEDBACK STORAGE
|
| 71 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
+
|
| 73 |
+
async def append_feedback(
    url: str,
    label: str,
    source: str = "user_feedback",
    original_prediction: Optional[float] = None,
) -> dict:
    """
    Thread-safe append of a feedback entry to feedback_data.jsonl.

    The asyncio lock serializes file appends; the in-memory counters are
    then updated and persisted via _save_state(). On a write failure the
    counters are left untouched and an error payload is returned.

    Returns: {"success": True, "feedback_count": N, "unprocessed": M}
    (or {"success": False, "error": ...} on I/O failure).
    """
    entry = {
        "url": url,
        "label": label,  # "phishing" or "safe"
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "source": source,
        "original_prediction": round(original_prediction, 4) if original_prediction is not None else None,
    }

    async with _write_lock:
        try:
            with open(FEEDBACK_FILE, "a") as f:
                f.write(json.dumps(entry) + "\n")
        except Exception as e:
            logger.error(f"[FeedbackStore] Write failed: {e}")
            return {"success": False, "error": str(e)}

    # Update in-memory state
    # NOTE(review): no await occurs between these mutations, so they are
    # atomic with respect to the single event loop.
    _retrain_state["total_feedback"] += 1
    _retrain_state["unprocessed_count"] += 1
    if label == "phishing":
        _retrain_state["phishing_corrections"] += 1
    elif label == "safe":
        _retrain_state["safe_corrections"] += 1

    _save_state()

    logger.info(f"[FeedbackStore] Saved | url={url} | label={label} | total={_retrain_state['total_feedback']}")

    return {
        "success": True,
        "feedback_count": _retrain_state["total_feedback"],
        "unprocessed": _retrain_state["unprocessed_count"],
    }
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_unprocessed_count() -> int:
    """Return how many feedback entries arrived since the last retraining."""
    pending = _retrain_state["unprocessed_count"]
    return pending
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_model_version() -> int:
    """Return the version number of the currently deployed model."""
    version = _retrain_state["model_version"]
    return version
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_stats() -> dict:
    """Assemble the feedback analytics payload for the /feedback/stats endpoint."""
    state = _retrain_state
    return {
        "total_feedback": state["total_feedback"],
        "phishing_corrections": state["phishing_corrections"],
        "safe_corrections": state["safe_corrections"],
        "unprocessed_count": state["unprocessed_count"],
        "last_retrain": state["last_retrain"],
        "model_version": state["model_version"],
        "retrain_history": state["retrain_history"][-10:],  # last 10
    }
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_recent_entries(n: int = 50) -> list:
    """Return up to the last *n* parsed entries from the feedback log.

    Missing file or any parse/read error yields an empty list.
    """
    if not os.path.exists(FEEDBACK_FILE):
        return []
    try:
        with open(FEEDBACK_FILE, "r") as fh:
            tail = fh.readlines()[-n:]
        parsed = []
        for raw_line in tail:
            raw_line = raw_line.strip()
            if raw_line:
                parsed.append(json.loads(raw_line))
        return parsed
    except Exception:
        return []
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ------------------------------------------------------------------------------
# RETRAINING PIPELINE
# ------------------------------------------------------------------------------

# Feedback entries required before a retraining run is triggered.
RETRAIN_THRESHOLD = 50
# Guard flag checked by should_retrain() to avoid overlapping runs.
_retrain_running = False
|
| 166 |
+
|
| 167 |
+
def should_retrain() -> bool:
    """True when enough new feedback accumulated and no retrain is in flight."""
    if _retrain_running:
        return False
    return _retrain_state["unprocessed_count"] >= RETRAIN_THRESHOLD
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def mark_retrain_complete(samples: int, accuracy: float) -> None:
    """
    Called after successful retraining.
    Increments model_version, resets unprocessed counter, logs history.

    The new state is persisted to disk via _save_state() before logging.
    """
    _retrain_state["model_version"] += 1
    _retrain_state["unprocessed_count"] = 0
    _retrain_state["last_retrain"] = datetime.now(timezone.utc).isoformat()
    _retrain_state["retrain_history"].append({
        "timestamp": _retrain_state["last_retrain"],
        "samples": samples,
        "accuracy": round(accuracy, 4),
        "version": _retrain_state["model_version"],
    })
    # Keep only last 50 history entries
    if len(_retrain_state["retrain_history"]) > 50:
        _retrain_state["retrain_history"] = _retrain_state["retrain_history"][-50:]
    _save_state()
    logger.info(
        f"[FeedbackStore] Retrained on {samples} feedback samples. "
        f"New accuracy: {accuracy:.2%}. Model version: {_retrain_state['model_version']}"
    )
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def archive_feedback_file():
    """Rotate the processed feedback log to a timestamped .bak file."""
    if not os.path.exists(FEEDBACK_FILE):
        return
    archive = f"{FEEDBACK_FILE}.{int(time.time())}.bak"
    try:
        shutil.move(FEEDBACK_FILE, archive)
        logger.info(f"[FeedbackStore] Archived feedback β {archive}")
    except Exception as e:
        logger.warning(f"[FeedbackStore] Archive failed: {e}")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def load_feedback_entries() -> list:
    """Parse every JSON line from the feedback file (empty list if absent)."""
    if not os.path.exists(FEEDBACK_FILE):
        return []
    records = []
    try:
        with open(FEEDBACK_FILE, "r") as fh:
            for raw_line in fh:
                stripped = raw_line.strip()
                if stripped:
                    records.append(json.loads(stripped))
    except Exception as e:
        logger.error(f"[FeedbackStore] Read failed: {e}")
    return records
|
generate_icons.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Generate PhishGuard extension icons at 16, 48, 128 px using pure Python."""
|
| 3 |
+
import struct, zlib, os
|
| 4 |
+
|
| 5 |
+
def create_png(width, height, pixels):
    """Create a minimal PNG from flat RGBA pixel data.

    *pixels* is a flat sequence of width*height*4 byte values (R, G, B, A
    per pixel, row-major). Returns the complete PNG file as bytes
    (8-bit RGBA, no interlacing, filter type 0 on every scanline).
    """
    def chunk(chunk_type, data):
        # PNG chunk layout: 4-byte big-endian length, type, payload,
        # CRC32 computed over type + payload.
        c = chunk_type + data
        return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)

    # Build raw scanlines with a bytearray: repeated `bytes +=` is
    # quadratic, and per-pixel slicing was O(pixels) extra work.
    raw = bytearray()
    row_bytes = width * 4
    for y in range(height):
        raw.append(0)  # filter byte 0: "None" filter for this scanline
        start = y * row_bytes
        raw.extend(bytes(pixels[start:start + row_bytes]))

    return (b'\x89PNG\r\n\x1a\n' +
            chunk(b'IHDR', struct.pack('>IIBBBBB', width, height, 8, 6, 0, 0, 0)) +
            chunk(b'IDAT', zlib.compress(bytes(raw))) +
            chunk(b'IEND', b''))
|
| 22 |
+
|
| 23 |
+
def draw_shield_icon(size):
    """Draw a shield icon with checkmark.

    Returns a flat list of size*size*4 RGBA byte values suitable for
    create_png(). The shield body is a purple gradient; a white checkmark
    is overlaid only for sizes >= 32 px (too small to render below that).
    """
    pixels = [0] * (size * size * 4)
    cx, cy = size / 2, size / 2
    # Shield gradient endpoints: base purple -> highlight purple.
    sr, sg, sb = 0x53, 0x4A, 0xB7
    hr, hg, hb = 0x7B, 0x73, 0xD4

    for y in range(size):
        for x in range(size):
            idx = (y * size + x) * 4
            # Normalized coordinates in [-1, 1] relative to the center.
            nx = (x - cx) / (size / 2)
            ny = (y - cy) / (size / 2)

            # Shield silhouette: straight sides up top, tapering toward
            # the bottom point.
            in_shield = False
            if ny < -0.05:
                if abs(nx) < 0.75:
                    in_shield = True
            elif ny < 0.5:
                w = 0.75 * (1 - ny * 0.8)
                if abs(nx) < w:
                    in_shield = True
            else:
                w = 0.75 * max(0, (1.0 - ny) * 1.4)
                if abs(nx) < w:
                    in_shield = True

            # Flat cut across the very top of the shield.
            if ny < -0.8:
                in_shield = False

            if in_shield:
                # Diagonal gradient: brighter toward the top-left.
                blend = max(0, min(1, 0.5 - nx * 0.3 - ny * 0.2))
                r = int(sr + (hr - sr) * blend)
                g = int(sg + (hg - sg) * blend)
                b = int(sb + (hb - sb) * blend)
                pixels[idx:idx+4] = [r, g, b, 255]
            else:
                pixels[idx:idx+4] = [0, 0, 0, 0]

    if size >= 32:
        # Sample 100 points along the two strokes of the checkmark
        # (down-stroke for the first 40%, up-stroke for the rest).
        check_points = []
        for t in range(100):
            p = t / 100.0
            if p < 0.4:
                pp = p / 0.4
                px = int(cx + (-0.25 + pp * 0.25) * size * 0.6)
                py = int(cy + (-0.1 + pp * 0.3) * size * 0.6)
            else:
                pp = (p - 0.4) / 0.6
                px = int(cx + (0.0 + pp * 0.35) * size * 0.6)
                py = int(cy + (0.2 - pp * 0.45) * size * 0.6)
            check_points.append((px, py))

        # Stamp a square brush at each sample; only paint where the shield
        # is already opaque so the checkmark never spills outside it.
        thickness = max(1, int(size * 0.06))
        for px, py in check_points:
            for dy in range(-thickness, thickness+1):
                for dx in range(-thickness, thickness+1):
                    xx, yy = px + dx, py + dy
                    if 0 <= xx < size and 0 <= yy < size:
                        idx = (yy * size + xx) * 4
                        if pixels[idx+3] > 0:
                            pixels[idx:idx+4] = [255, 255, 255, 240]

    return pixels
|
| 86 |
+
|
| 87 |
+
# Output directory: ../extension/icons relative to this script's location.
icons_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'extension', 'icons')
os.makedirs(icons_dir, exist_ok=True)

# Chrome extensions expect icons at exactly 16, 48 and 128 px.
for size in [16, 48, 128]:
    pixels = draw_shield_icon(size)
    png_data = create_png(size, size, pixels)
    path = os.path.join(icons_dir, f'icon{size}.png')
    with open(path, 'wb') as f:
        f.write(png_data)
    print(f"Created {path} ({len(png_data)} bytes)")

print("Done! All icons generated.")
|
gmail_scanner.js
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// gmail_scanner.js
// Content script: scans opened Gmail messages and injects warning banners.

console.log("PhishGuard AI: Gmail Scanner loaded.");

// Gmail's DOM can change between releases; '.a3s' and '.ii.gt' are the
// classes that commonly wrap the rendered email body text.
const EMAIL_BODY_SELECTOR = '.a3s, .ii.gt';
|
| 7 |
+
|
| 8 |
+
// Inject a visible PhishGuard warning banner directly above an email body.
function injectWarningBanner(emailContainer, message) {
    // Bail out if a banner was already injected for this message.
    if (emailContainer.parentNode.querySelector('.phishguard-banner')) {
        return;
    }

    const banner = document.createElement('div');
    banner.className = 'phishguard-banner';

    // Urgent-but-native styling that fits Google's Material Design.
    Object.assign(banner.style, {
        backgroundColor: '#fce8e6',
        color: '#c5221f',
        border: '1px solid #faa59f',
        borderRadius: '8px',
        padding: '12px 16px',
        margin: '16px auto',  // keeps the banner clear of the email text
        fontFamily: '"Google Sans", Roboto, Arial, sans-serif',
        fontSize: '14px',
        fontWeight: '500',
        lineHeight: '20px',
        display: 'flex',
        alignItems: 'center',
        boxShadow: '0 1px 2px 0 rgba(60,64,67,0.3), 0 1px 3px 1px rgba(60,64,67,0.15)'
    });

    // Warning-triangle SVG icon.
    const iconSvg = `
    <svg focusable="false" width="24" height="24" viewBox="0 0 24 24" style="fill: #c5221f; margin-right: 16px; flex-shrink: 0;">
        <path d="M1 21h22L12 2 1 21zm12-3h-2v-2h2v2zm0-4h-2v-4h2v4z"></path>
    </svg>
    `;

    const textContent = document.createElement('span');
    textContent.innerText = message || 'π¨ PhishGuard Warning: This email contains suspicious links.';

    // Assemble: icon markup first, then the text node.
    banner.innerHTML = iconSvg;
    banner.appendChild(textContent);

    // Insert the banner just before the email body so it stays on top.
    if (emailContainer.parentNode) {
        emailContainer.parentNode.insertBefore(banner, emailContainer);
    }
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
// Helper: collect all unique http/https URLs found in the email body.
function extractUrlsFromBody(emailContainer) {
    const seen = new Set(); // de-duplicates while preserving first-seen order

    for (const anchor of emailContainer.querySelectorAll('a[href]')) {
        const href = anchor.href;
        // Keep only web links; silently drops mailto:, javascript:, tel:, etc.
        if (href && (href.startsWith('http://') || href.startsWith('https://'))) {
            seen.add(href);
        }
    }

    return Array.from(seen);
}
|
| 71 |
+
|
| 72 |
+
// Safely extract the sender's actual email address for the current message.
function extractSenderEmail(emailContainer) {
    // Gmail wraps each message in a container block ('.adn' or '.kv'); walk up
    // so we query the header belonging to *this* message, falling back to the
    // whole document when no wrapper is found.
    const wrapper =
        emailContainer.closest('.adn') ||
        emailContainer.closest('.kv') ||
        document;

    // Primary path: the '.gD' sender element carries an 'email' attribute.
    const senderEl = wrapper.querySelector('.gD');
    if (senderEl) {
        const addr = senderEl.getAttribute('email');
        if (addr) {
            return addr;
        }
    }

    // Fallback: a '.go' element whose text holds the address in angle
    // brackets, e.g. "<sender@example.com>".
    const fallbackEl = wrapper.querySelector('.go');
    if (fallbackEl && fallbackEl.innerText) {
        const match = fallbackEl.innerText.match(/<([^>]+)>/);
        if (match && match[1]) {
            return match[1];
        }
    }

    return "Unknown Sender";
}
|
| 98 |
+
|
| 99 |
+
// Handle a newly opened email: extract its data, send it to the background
// service worker, and inject a warning banner if the analysis flags it.
function handleEmailOpened(emailContainer) {
    // Skip elements we have already processed in this session.
    if (emailContainer.dataset.pgScanned === "true") {
        return;
    }

    console.log("PhishGuard AI: New email thread opened. Extracting data...");

    // Mark as scanned before any async work to avoid double-processing.
    emailContainer.dataset.pgScanned = "true";

    // Subject line: 'h2.hP' is Gmail's standard subject class; '.bog' is a
    // known alternative location.
    const subjectEl = document.querySelector('h2.hP') || document.querySelector('.bog');

    // Package everything the backend needs into a JSON payload.
    const emailPayload = {
        sender: extractSenderEmail(emailContainer),
        subject: subjectEl ? subjectEl.innerText.trim() : "No Subject Found",
        body: emailContainer.innerText,
        urls: extractUrlsFromBody(emailContainer),
        timestamp: new Date().toISOString()
    };

    console.log("PhishGuard AI extracted payload:", emailPayload);

    // Hand off to the service worker for analysis.
    chrome.runtime.sendMessage(
        { action: "analyzeEmail", data: emailPayload },
        (response) => {
            if (chrome.runtime.lastError) {
                console.error("PhishGuard AI: Error communicating with background script:", chrome.runtime.lastError);
                return;
            }

            console.log("PhishGuard AI Background Analysis Response:", response);

            // Expect `response.analysis` carrying `isPhishing` and/or
            // `probability` from the background script.
            const analysis = (response && response.analysis) ? response.analysis : {};

            if (analysis.isPhishing === true || analysis.probability > 0.70) {
                console.warn("PhishGuard AI: High risk email detected! Injecting banner...");
                injectWarningBanner(
                    emailContainer,
                    '🚨 PhishGuard Warning: This email contains suspicious links and exhibits high-risk phishing behavior.'
                );
            }
        }
    );
}
|
| 163 |
+
|
| 164 |
+
// Watch the DOM so we notice when Gmail dynamically loads an individual
// email view (Gmail is a single-page app; navigation never reloads the page).
const observer = new MutationObserver((mutationsList) => {
    for (const mutation of mutationsList) {
        if (mutation.type !== 'childList') {
            continue;
        }
        for (const node of mutation.addedNodes) {
            if (node.nodeType !== Node.ELEMENT_NODE) {
                continue;
            }
            // The added node may itself be the email body container...
            if (node.matches && node.matches(EMAIL_BODY_SELECTOR)) {
                handleEmailOpened(node);
            }
            // ...or contain one (or several) deeper in its subtree.
            if (node.querySelectorAll) {
                node.querySelectorAll(EMAIL_BODY_SELECTOR)
                    .forEach((body) => handleEmailOpened(body));
            }
        }
    }
});

// Observe the whole document so navigation between emails is caught too.
observer.observe(document.body, {
    childList: true,
    subtree: true
});

console.log("PhishGuard AI: MutationObserver is listening for email thread opens.");
|
gnn_inference.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - gnn/gnn_inference.py
|
| 3 |
+
# GNN inference wrapper for main.py.
|
| 4 |
+
# Loads model once at startup, reuses for every request.
|
| 5 |
+
# Supports: predict, hot-reload, incremental_update.
|
| 6 |
+
# ============================================================
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import random
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("phishguard.gnn.inference")
|
| 20 |
+
|
| 21 |
+
# Add parent paths
|
| 22 |
+
_GNN_DIR = Path(__file__).parent
|
| 23 |
+
_BACKEND_DIR = _GNN_DIR.parent
|
| 24 |
+
sys.path.insert(0, str(_GNN_DIR))
|
| 25 |
+
sys.path.insert(0, str(_BACKEND_DIR))
|
| 26 |
+
|
| 27 |
+
from domain_graph_builder import DomainGraphBuilder
|
| 28 |
+
from gnn_model import load_gnn_model, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM
|
| 29 |
+
|
| 30 |
+
if PYGEOM_AVAILABLE:
|
| 31 |
+
from gnn_model import PhishGNN
|
| 32 |
+
|
| 33 |
+
MODEL_PATH = _GNN_DIR / "gnn_weights.pt"
|
| 34 |
+
REPLAY_BUFFER_PATH = _BACKEND_DIR / "data" / "gnn_replay_buffer.pt"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GNNInference:
    """
    GNN inference wrapper with hot-reload and incremental-update support.

    The model is loaded lazily on first use and reused for every request.
    When no trained weights exist, ``load_gnn_model`` returns an untrained
    model; ``predict`` returns a neutral 0.5 if even that fails.
    """

    def __init__(self, weights_path: Optional[Path] = None) -> None:
        self._weights_path = weights_path or MODEL_PATH
        self._model: Optional[torch.nn.Module] = None
        self._builder = DomainGraphBuilder()
        self._loaded = False

    @staticmethod
    def _graph_tensors(graph: dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a builder graph dict to (node features, edge_index).

        When the graph has no edges, self-loops are synthesized so GCN
        layers still propagate each node's own features.
        """
        x = torch.tensor(graph["features"], dtype=torch.float)
        edges = graph["edges"]
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            n = x.size(0)
            edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
        return x, edge_index

    @staticmethod
    def _accuracy(model: torch.nn.Module, dataset: list, device: torch.device) -> float:
        """Fraction of graphs classified correctly at a 0.5 threshold."""
        model.eval()
        correct = 0
        with torch.no_grad():
            for item in dataset:
                out = model(item["x"].to(device), item["edge_index"].to(device))
                pred = 1 if out.squeeze().item() >= 0.5 else 0
                correct += int(pred == int(item["y"].item()))
        return correct / len(dataset)

    def load(self, weights_path: Optional[Path] = None) -> bool:
        """Load the GNN model from the weights file; True on success."""
        path = weights_path or self._weights_path
        self._model = load_gnn_model(str(path) if path.exists() else None)
        self._loaded = self._model is not None
        if self._loaded:
            logger.info(f"GNN model loaded from {path}")
        return self._loaded

    def predict(self, url: str, related_urls: Optional[List[str]] = None) -> float:
        """
        Predict phishing probability for *url*.

        Returns P_gnn in [0, 1]; 0.5 (neutral) when no model is available.
        A single URL builds a one-node graph, which effectively exercises
        the MLP fallback path inside the model.
        """
        if not self._loaded:
            self.load()
        if self._model is None:
            return 0.5  # Neutral when model unavailable

        urls = [url] + (related_urls or [])
        if len(urls) == 1:
            graph = self._builder.build_single_node_graph(url)
        else:
            graph = self._builder.build_graph(urls)

        x, edge_index = self._graph_tensors(graph)
        prob = self._model.predict_proba(x, edge_index)
        return round(float(prob), 4)

    def reload(self, weights_path: Optional[Path] = None) -> bool:
        """Hot-swap new weights without a server restart.

        Keeps the currently loaded model when the new one fails to load.
        """
        path = weights_path or self._weights_path
        new_model = load_gnn_model(str(path))
        if new_model is None:
            logger.warning(f"GNN hot-reload failed from {path}")
            return False
        self._model = new_model
        self._loaded = True
        logger.info(f"GNN model hot-reloaded from {path}")
        return True

    def incremental_update(
        self,
        samples: List[Tuple[str, int]],
        replay_buffer_path: Optional[Path] = None,
        lr: float = 5e-4,
        epochs: int = 5,
    ) -> Optional[float]:
        """
        Fine-tune the model on feedback *samples* mixed with a replay buffer.

        samples: list of (url, label) pairs, label in {0, 1} (1 = phishing).
        Returns the accuracy delta (post - pre) on the training set, or None
        when skipped/failed. Side effects: weights saved to
        ``self._weights_path``; rolling replay buffer (max 500) extended.
        """
        if self._model is None:
            logger.warning("GNN not loaded, cannot incrementally update")
            return None
        if len(samples) < 5:
            logger.warning(f"Too few samples ({len(samples)}) for GNN update")
            return None

        try:
            import torch.nn.functional as F

            device = torch.device("cpu")
            model = self._model.to(device)
            builder = DomainGraphBuilder()

            # Build small (<= CHUNK node) graphs per class from the feedback.
            CHUNK = 4
            new_graphs = []
            phish = [url for url, label in samples if label == 1]
            legit = [url for url, label in samples if label == 0]
            for urls, label in ((phish, 1), (legit, 0)):
                for start in range(0, len(urls), CHUNK):
                    chunk = urls[start:start + CHUNK]
                    if not chunk:
                        continue
                    x, ei = self._graph_tensors(builder.build_graph(chunk))
                    new_graphs.append({
                        "x": x, "edge_index": ei,
                        "y": torch.tensor([float(label)]),
                    })

            # Mix in ~20% of the replay buffer to resist catastrophic forgetting.
            buf_path = replay_buffer_path or REPLAY_BUFFER_PATH
            replay_graphs = []
            if buf_path.exists():
                try:
                    all_replay = torch.load(buf_path, map_location="cpu", weights_only=False)
                    replay_count = max(1, len(all_replay) // 5)  # 20%
                    replay_graphs = random.sample(all_replay, min(replay_count, len(all_replay)))
                except Exception as e:
                    logger.warning(f"Replay buffer load failed: {e}")

            # Merge: 80% new + 20% replay.
            dataset = new_graphs + replay_graphs
            random.shuffle(dataset)
            if not dataset:
                return None

            pre_acc = self._accuracy(model, dataset, device)

            # Train — the dataset is tiny, so a handful of epochs suffices.
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
            model.train()
            for epoch in range(epochs):
                random.shuffle(dataset)
                total_loss = 0.0
                for item in dataset:
                    x = item["x"].to(device)
                    ei = item["edge_index"].to(device)
                    y = item["y"].to(device)
                    optimizer.zero_grad()
                    out = model(x, ei)
                    loss = F.binary_cross_entropy(out.squeeze(), y.squeeze())
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(f"GNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(dataset):.4f}")

            post_acc = self._accuracy(model, dataset, device)
            delta = post_acc - pre_acc
            self._model = model

            # Persist updated weights. BUGFIX: ensure the parent directory
            # exists first — torch.save raises if it does not (the replay
            # buffer path below already had this guard; the weights path
            # did not).
            self._weights_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), self._weights_path)
            logger.info(f"GNN incremental update: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")

            # Extend the rolling replay buffer (capped at 500 graphs).
            try:
                existing = []
                if buf_path.exists():
                    existing = torch.load(buf_path, map_location="cpu", weights_only=False)
                combined = existing + new_graphs
                if len(combined) > 500:
                    combined = combined[-500:]
                buf_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(combined, buf_path)
            except Exception as e:
                logger.warning(f"Replay buffer update failed: {e}")

            return round(delta, 4)

        except Exception as e:
            logger.error(f"GNN incremental update failed: {e}")
            return None

    @property
    def is_loaded(self) -> bool:
        """True once a model object has been successfully loaded."""
        return self._loaded
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ── Legacy compatibility functions ───────────────────────────────────
# Module-level singleton so every request reuses one loaded model.
_inference = GNNInference()


def analyze_url_with_gnn(url: str, related_urls: list = None) -> dict:
    """Legacy wrapper for backward compatibility."""
    if not _inference.is_loaded:
        _inference.load()

    if not _inference.is_loaded:
        # Model could not be loaded at all — report that explicitly.
        return {
            "gnn_phish_prob": None,
            "tier3_status": "model_not_loaded",
            "node_count": 0,
            "edge_count": 0,
            "graph_suspicious": False,
        }

    probability = _inference.predict(url, related_urls)
    return {
        "gnn_phish_prob": probability,
        "node_count": 1 + len(related_urls or []),
        "edge_count": 0,
        "graph_suspicious": probability > 0.6,
    }


def reload_model(new_weights_path: str = None) -> bool:
    """Legacy wrapper around GNNInference.reload()."""
    return _inference.reload(Path(new_weights_path) if new_weights_path else None)


def is_model_loaded() -> bool:
    """Legacy wrapper exposing GNNInference.is_loaded."""
    return _inference.is_loaded
|
gnn_model.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - gnn/gnn_model.py
|
| 3 |
+
# GNN + MLP model definitions for phishing graph classification.
|
| 4 |
+
#
|
| 5 |
+
# PhishGNN: 3-layer GCN with global_mean_pool β Linear β Sigmoid
|
| 6 |
+
# GCNConv(12β64) β ReLU β GCNConv(64β32) β ReLU β
|
| 7 |
+
# GCNConv(32β16) β global_mean_pool β Linear(16β1) β Sigmoid
|
| 8 |
+
#
|
| 9 |
+
# PhishMLP: Fallback for single URL or when torch_geometric unavailable
|
| 10 |
+
# Linear(12β64) β ReLU β Dropout(0.3) β Linear(64β1) β Sigmoid
|
| 11 |
+
# ============================================================
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("phishguard.gnn.model")
|
| 24 |
+
|
| 25 |
+
INPUT_DIM: int = 12 # 12-dim node features
|
| 26 |
+
HIDDEN_DIM: int = 64
|
| 27 |
+
OUTPUT_DIM: int = 1 # binary: sigmoid output
|
| 28 |
+
|
| 29 |
+
# ββ Try importing PyTorch Geometric ββββββββββββββββββββββββββββββββββ
|
| 30 |
+
PYGEOM_AVAILABLE: bool = False
|
| 31 |
+
try:
|
| 32 |
+
from torch_geometric.nn import GCNConv, global_mean_pool
|
| 33 |
+
PYGEOM_AVAILABLE = True
|
| 34 |
+
logger.info("PyTorch Geometric found β using full GCN model")
|
| 35 |
+
except ImportError:
|
| 36 |
+
PYGEOM_AVAILABLE = False
|
| 37 |
+
logger.info("PyTorch Geometric not found β using MLP fallback")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ── PhishGNN: Full 3-layer Graph Convolutional Network ───────────────
if PYGEOM_AVAILABLE:
    class PhishGNN(nn.Module):
        """
        3-layer GCN for graph-level phishing classification.

        Architecture (from spec):
            GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
            GCNConv(32→16) → global_mean_pool → Linear(16→1) → Sigmoid
        """

        def __init__(
            self,
            in_channels: int = INPUT_DIM,
            hidden: int = HIDDEN_DIM,
            out_channels: int = OUTPUT_DIM,
        ) -> None:
            super().__init__()
            half, quarter = hidden // 2, hidden // 4
            self.conv1 = GCNConv(in_channels, hidden)   # 12 → 64
            self.conv2 = GCNConv(hidden, half)          # 64 → 32
            self.conv3 = GCNConv(half, quarter)         # 32 → 16
            self.fc = nn.Linear(quarter, out_channels)  # 16 → 1

        def forward(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # Normalize an empty edge_index to the canonical (2, 0) shape.
            if edge_index.numel() == 0:
                edge_index = torch.zeros((2, 0), dtype=torch.long, device=x.device)

            # Three message-passing rounds with ReLU non-linearities.
            for conv in (self.conv1, self.conv2, self.conv3):
                x = F.relu(conv(x, edge_index))

            # No batch vector means "one implicit graph".
            if batch is None:
                batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

            pooled = global_mean_pool(x, batch)    # (batch_size, 16)
            return torch.sigmoid(self.fc(pooled))  # values in [0, 1]

        def predict_proba(
            self,
            x: torch.Tensor,
            edge_index: torch.Tensor,
            batch: Optional[torch.Tensor] = None,
        ) -> float:
            """Return P_gnn in [0, 1] — probability of phishing."""
            self.eval()
            with torch.no_grad():
                return self.forward(x, edge_index, batch).squeeze().item()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ── PhishMLP: Fallback for single URL or no torch_geometric ──────────
class PhishMLP(nn.Module):
    """
    MLP fallback for phishing classification.

    Used when torch_geometric is unavailable or the graph has < 2 nodes.
    Architecture: Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid
    """

    def __init__(self, in_channels: int = INPUT_DIM) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,  # accepted for GNN API parity; unused
        batch: Optional[torch.Tensor] = None,       # accepted for GNN API parity; unused
    ) -> torch.Tensor:
        # Collapse a multi-node feature matrix into one mean-pooled row;
        # promote a bare 1-D feature vector to a (1, dim) batch.
        if x.dim() == 2 and x.size(0) > 1:
            x = x.mean(dim=0, keepdim=True)
        elif x.dim() == 1:
            x = x.unsqueeze(0)
        return torch.sigmoid(self.net(x))

    def predict_proba(
        self,
        x: torch.Tensor,
        edge_index: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> float:
        """Return P_gnn in [0, 1] — probability of phishing."""
        self.eval()
        with torch.no_grad():
            return self.forward(x, edge_index, batch).squeeze().item()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ── Model loading utility ────────────────────────────────────────────
def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """
    Build a GNN (or MLP fallback) and optionally load trained weights.

    Returns the model in eval mode, or None if construction fails entirely.
    A missing or incompatible weights file is non-fatal: the untrained
    model is returned instead.
    """
    # Instantiate: prefer the full GCN, fall back to the plain MLP.
    try:
        model: Optional[nn.Module] = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
    except Exception as e:
        logger.error(f"GNN model creation failed: {e}")
        try:
            model = PhishMLP()
        except Exception as e2:
            logger.error(f"MLP fallback creation also failed: {e2}")
            return None

    # Optionally restore trained weights.
    if model_path and os.path.exists(model_path):
        try:
            state = torch.load(model_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state)
            logger.info(f"GNN weights loaded from {model_path}")
        except RuntimeError as e:
            # Shape/key mismatch: typically the architecture changed.
            logger.warning(f"GNN weights mismatch (architecture changed?): {e}")
        except Exception as e:
            logger.warning(f"GNN weight load failed: {e}")
    elif model_path:
        logger.info(f"GNN weights file not found: {model_path}")
    else:
        logger.info("No GNN weights path → using untrained model")

    # Switch to inference mode before handing the model out.
    try:
        model.eval()
    except Exception as e:
        logger.error(f"GNN eval() failed: {e}")
        return None

    return model
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Legacy alias
def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
    """Backward-compatible alias for load_gnn_model()."""
    return load_gnn_model(model_path)
|
keep_alive.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
# PhishGuard AI - keep_alive.py
# Keeps your Render.com server awake 24/7.
#
# From architecture doc 5.3:
#   Render free tier sleeps after 15 min of inactivity.
#   This pings GET /health every 14 min to prevent that.
#
# WHERE TO RUN:
#   - On a second laptop / old computer
#   - On your phone using Termux (free Android app)
#   - On a friend's computer
#   - On your own laptop in a separate terminal (less ideal)
#
# HOW TO RUN:
#   python keep_alive.py
#   (keep this terminal window open — never close it)
# ============================================================

import time
import datetime

import requests

# !! CHANGE THIS to your actual Render URL !!
API_URL = "https://YOUR-APP-NAME.onrender.com/health"

INTERVAL = 14 * 60  # 14 minutes in seconds

print("=" * 50)
print("PhishGuard Keep-Alive Script")
print("=" * 50)
print(f"Pinging: {API_URL}")
# BUGFIX: derive the banner from INTERVAL instead of hard-coding
# "14 minutes", so the message stays correct if INTERVAL is edited.
print(f"Every: {INTERVAL // 60} minutes")
print(f"Started: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}")
print("\nDO NOT close this window!")
print("Press Ctrl+C to stop.\n")

ping_count = 0
try:
    while True:
        try:
            r = requests.get(API_URL, timeout=15)
            ping_count += 1
            status = "OK" if r.status_code == 200 else f"ERROR {r.status_code}"
            print(f"[{datetime.datetime.now():%H:%M:%S}] Ping #{ping_count} → {status}")
        except requests.exceptions.ConnectionError:
            # Cold-start case: Render may take a moment to spin up.
            print(f"[{datetime.datetime.now():%H:%M:%S}] Connection failed → server might be waking up...")
        except Exception as e:
            print(f"[{datetime.datetime.now():%H:%M:%S}] Error: {e}")

        time.sleep(INTERVAL)
except KeyboardInterrupt:
    # Graceful Ctrl+C exit instead of a raw traceback.
    print(f"\nStopped after {ping_count} ping(s).")
|
main.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - main.py
|
| 3 |
+
# FastAPI orchestrator β Full 4-tier phishing detection pipeline
|
| 4 |
+
# with feedback-driven incremental retraining.
|
| 5 |
+
#
|
| 6 |
+
# Endpoints:
|
| 7 |
+
# POST /analyze β 4-tier URL phishing analysis
|
| 8 |
+
# POST /analyze/email β BERT-only email body analysis
|
| 9 |
+
# POST /retrain β Incremental model retraining
|
| 10 |
+
# GET /model_version β Current model version info
|
| 11 |
+
# GET /health β All model load statuses
|
| 12 |
+
#
|
| 13 |
+
# Architecture:
|
| 14 |
+
# Tier 1: Whitelist O(1) β SAFE exit (~55% traffic)
|
| 15 |
+
# Tier 2: Heuristic 15 signals β BLOCK if >= 80 (~15% blocked)
|
| 16 |
+
# Tier 3: BERT+GNN parallel β BLOCK/SAFE/escalate (~15% exits)
|
| 17 |
+
# Tier 4: CNN visual + brand hash β BLOCK/SAFE (~15% borderline)
|
| 18 |
+
# ============================================================
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
import asyncio
|
| 25 |
+
import time
|
| 26 |
+
import hashlib
|
| 27 |
+
import logging
|
| 28 |
+
import logging.handlers
|
| 29 |
+
from collections import OrderedDict
|
| 30 |
+
from contextlib import asynccontextmanager
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import List, Optional
|
| 33 |
+
|
| 34 |
+
from fastapi import FastAPI
|
| 35 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 36 |
+
from pydantic import BaseModel
|
| 37 |
+
|
| 38 |
+
# ββ Path setup ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
BASE_DIR = Path(__file__).parent
|
| 40 |
+
for sub_dir in ["gnn", "cnn"]:
|
| 41 |
+
sub_path = BASE_DIR / sub_dir
|
| 42 |
+
if sub_path.is_dir():
|
| 43 |
+
sys.path.insert(0, str(sub_path))
|
| 44 |
+
|
| 45 |
+
# ββ Logging βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
# Log to a size-rotated file under <app>/logs plus stderr for the console.
log_dir = BASE_DIR / "logs"
log_dir.mkdir(exist_ok=True)

# 5 MB per file, 3 backups -> at most ~20 MB of logs on disk.
_handler = logging.handlers.RotatingFileHandler(
    log_dir / "phishguard.log",
    maxBytes=5 * 1024 * 1024,
    backupCount=3,
    encoding="utf-8",
)
_handler.setFormatter(logging.Formatter(
    "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
))

# Application-wide logger; modules log through the "phishguard" namespace.
logger = logging.getLogger("phishguard")
logger.setLevel(logging.INFO)
logger.addHandler(_handler)
logger.addHandler(logging.StreamHandler())
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ββ Import project modules βββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
from url_heuristics import HeuristicScorer, HeuristicResult
|
| 68 |
+
from bert_analyzer import BERTPhishingClassifier
|
| 69 |
+
|
| 70 |
+
# GNN imports
|
| 71 |
+
GNN_AVAILABLE = False
|
| 72 |
+
gnn_inference = None
|
| 73 |
+
try:
|
| 74 |
+
from gnn.gnn_inference import GNNInference
|
| 75 |
+
GNN_AVAILABLE = True
|
| 76 |
+
except ImportError:
|
| 77 |
+
try:
|
| 78 |
+
from gnn_inference import GNNInference
|
| 79 |
+
GNN_AVAILABLE = True
|
| 80 |
+
except ImportError:
|
| 81 |
+
logger.warning("GNN module not available")
|
| 82 |
+
|
| 83 |
+
# CNN imports
|
| 84 |
+
CNN_AVAILABLE = False
|
| 85 |
+
cnn_inference = None
|
| 86 |
+
brand_detector = None
|
| 87 |
+
try:
|
| 88 |
+
from cnn.cnn_inference import CNNInference
|
| 89 |
+
from cnn.screenshot_hasher import BrandHashDetector
|
| 90 |
+
from cnn.cnn_model import preprocess_screenshot
|
| 91 |
+
CNN_AVAILABLE = True
|
| 92 |
+
except ImportError:
|
| 93 |
+
try:
|
| 94 |
+
from cnn_inference import CNNInference
|
| 95 |
+
from screenshot_hasher import BrandHashDetector
|
| 96 |
+
from cnn_model import preprocess_screenshot
|
| 97 |
+
CNN_AVAILABLE = True
|
| 98 |
+
except ImportError:
|
| 99 |
+
logger.warning("CNN module not available")
|
| 100 |
+
|
| 101 |
+
from tier3_bert_gnn import Tier3Ensemble
|
| 102 |
+
from retraining_service import RetrainingService, FeedbackRecord, RetrainResult
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ββ Whitelist (Tier 1) ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 106 |
+
# Tier 1 allow-list: two-label root domains that exit the pipeline
# immediately as SAFE.  Compared against get_root_domain(url), so only
# registrable roots (not subdomains) belong here.
WHITELIST: set[str] = {
    "google.com", "youtube.com", "facebook.com", "amazon.com", "wikipedia.org",
    "twitter.com", "instagram.com", "linkedin.com", "microsoft.com", "apple.com",
    "github.com", "stackoverflow.com", "reddit.com", "netflix.com", "paypal.com",
    "bankofamerica.com", "chase.com", "wellsfargo.com", "yahoo.com", "bing.com",
    "outlook.com", "office.com", "live.com", "adobe.com", "dropbox.com",
    "zoom.us", "slack.com", "spotify.com", "twitch.tv", "ebay.com",
    "walmart.com", "target.com", "bestbuy.com", "airbnb.com",
    "x.com", "tiktok.com", "pinterest.com", "quora.com", "medium.com",
}
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_root_domain(url: str) -> str:
    """Return the registrable root domain of *url* (e.g. "google.com").

    Strips a leading "www." label, then keeps the last two dot-separated
    labels.  NOTE(review): this is a heuristic — multi-part public suffixes
    such as "co.uk" collapse to the suffix itself; confirm the whitelist
    only contains two-label roots.

    Returns "" when the URL has no parseable hostname.
    """
    from urllib.parse import urlparse
    try:
        host = urlparse(url).hostname or ""
        # Strip only a *leading* "www." — the previous str.replace() call
        # also deleted "www." labels appearing anywhere in the hostname
        # (e.g. "xwww.site.org" became "xsite.org").
        if host.startswith("www."):
            host = host[4:]
        parts = host.split(".")
        return ".".join(parts[-2:]) if len(parts) >= 2 else host
    except Exception:
        # Defensive: never let a malformed URL break the pipeline.
        return ""
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ββ URL Cache (LRU, 30-min TTL) ββββββββββββββββββββββββββββββββββββββ
|
| 131 |
+
CACHE_TTL = 30 * 60   # seconds a cached verdict stays valid (30 min)
CACHE_MAX = 500       # max entries before LRU eviction kicks in


class URLCache:
    """In-memory LRU cache mapping URL -> analysis result dict.

    Entries older than ``ttl`` seconds are discarded lazily on lookup;
    once more than ``maxsize`` entries are held, the least recently used
    one is evicted on insert.
    """

    def __init__(self, maxsize: int = CACHE_MAX, ttl: int = CACHE_TTL) -> None:
        self._cache: OrderedDict = OrderedDict()
        self._maxsize = maxsize
        self._ttl = ttl

    def get(self, url: str) -> Optional[dict]:
        """Return the cached result for *url*, or None if absent/expired."""
        entry = self._cache.get(url)
        if entry is None:
            return None
        if time.time() - entry["ts"] >= self._ttl:
            # Stale entry — drop it so it can never be served again.
            del self._cache[url]
            return None
        # Refresh recency so frequently-seen URLs survive eviction.
        self._cache.move_to_end(url)
        return entry["result"]

    def set(self, url: str, result: dict) -> None:
        """Insert or refresh *url*, evicting the LRU entry when over capacity."""
        self._cache[url] = {"result": result, "ts": time.time()}
        self._cache.move_to_end(url)
        while len(self._cache) > self._maxsize:
            self._cache.popitem(last=False)

    def clear(self) -> None:
        """Drop every cached verdict (called after successful retraining)."""
        self._cache.clear()


_url_cache = URLCache()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ββ Request/Response Models βββββββββββββββββββββββββββββββββββββββββββ
|
| 165 |
+
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""
    url: str                      # URL to run through the 4-tier pipeline
    heuristic_score: float = 0.0  # optional browser-side heuristic score (0-100)
    page_title: str = ""          # page <title>, fed to the BERT text model
    page_snippet: str = ""        # visible page-text excerpt, fed to BERT
    related_urls: list = []       # not read by the visible /analyze path; reserved
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class EmailRequest(BaseModel):
    """Request body for POST /analyze/email."""
    sender: str         # sender address; its domain part is whitelist-checked
    subject: str = ""   # email subject, used as page_title for URL analysis
    body: str = ""      # email body text, used for BERT-only analysis
    urls: list = []     # URLs extracted from the email (only the first 3 analyzed)
    timestamp: str = ""  # client-supplied timestamp; not read server-side here
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class FeedbackSample(BaseModel):
    """One labeled feedback record sent by the extension for retraining."""
    url: str
    verdict: str = ""                    # verdict originally shown (e.g. "blocked"/"safe")
    confidence: float = 0.0              # model confidence for that verdict
    tier_used: int = 0                   # pipeline tier that produced the verdict (1-4)
    heuristic_score: int = 0             # Tier 2 score at analysis time
    signals: list = []                   # heuristic signal names that fired
    user_feedback: Optional[str] = None  # the user's correction label, if any
    timestamp: str = ""                  # when the original analysis happened
    feedback_ts: Optional[str] = None    # when the user submitted feedback
    url_hash: str = ""                   # client-side hash of the URL
    session_id: str = ""                 # extension session identifier
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class RetrainRequest(BaseModel):
    """Request body for POST /retrain."""
    samples: List[FeedbackSample]  # labeled feedback records to train on
    trigger: str = "count"         # what prompted the request (default: sample count)
    session_id: str = ""           # extension session identifier
    extension_version: str = ""    # client extension version, for bookkeeping
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ββ Global state ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 203 |
+
# Module-level singletons, populated once in lifespan() at startup.
# The annotations stay valid even when GNN/CNN imports failed because
# `from __future__ import annotations` (top of file) makes them lazy strings.
_scorer: Optional[HeuristicScorer] = None          # Tier 2
_bert: Optional[BERTPhishingClassifier] = None     # Tier 3a
_gnn: Optional[GNNInference] = None                # Tier 3b (may stay None)
_cnn: Optional[CNNInference] = None                # Tier 4 (may stay None)
_brand: Optional[BrandHashDetector] = None         # Tier 4 brand-hash DB
_tier3: Optional[Tier3Ensemble] = None             # BERT+GNN ensemble
_retrain_service: Optional[RetrainingService] = None
_retrain_lock = asyncio.Lock()  # serializes /retrain jobs
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# ββ Lifespan (startup/shutdown) βββββββββββββββββββββββββββββββββββββββ
|
| 214 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models at startup, clean up at shutdown.

    Populates the module-level singletons (_scorer, _bert, _gnn, _cnn,
    _brand, _tier3, _retrain_service).  GNN and CNN are optional: when
    their modules failed to import, the corresponding tiers are disabled
    and the pipeline degrades gracefully.
    """
    global _scorer, _bert, _gnn, _cnn, _brand, _tier3, _retrain_service

    logger.info("=== PhishGuard AI starting up ===")

    # Tier 2: Heuristic Scorer (always available, pure Python)
    _scorer = HeuristicScorer()
    logger.info("β Tier 2: HeuristicScorer initialized")

    # Tier 3a: BERT (weights load lazily on first predict)
    _bert = BERTPhishingClassifier()
    logger.info("β Tier 3a: BERT classifier initialized (lazy-load)")

    # Tier 3b: GNN
    if GNN_AVAILABLE:
        _gnn = GNNInference()
        _gnn.load()
        logger.info(f"β Tier 3b: GNN loaded={_gnn.is_loaded}")
    else:
        _gnn = None
        logger.warning("β Tier 3b: GNN not available")

    # Tier 3 Ensemble (requires both BERT and GNN)
    if _gnn:
        _tier3 = Tier3Ensemble(_bert, _gnn)
        logger.info("β Tier 3: Ensemble initialized")
    else:
        _tier3 = None
        logger.warning("β Tier 3: Ensemble not available (GNN missing)")

    # Tier 4: CNN + Brand Detection
    if CNN_AVAILABLE:
        _cnn = CNNInference()
        _cnn.load()
        _brand = BrandHashDetector()
        logger.info(f"β Tier 4: CNN loaded={_cnn.is_loaded}, Brand hash DB loaded")
    else:
        _cnn = None
        _brand = None
        logger.warning("β Tier 4: CNN not available")

    # Retraining Service.
    # FIX: the previous `_gnn or GNNInference()` raised NameError at startup
    # when the GNN module never imported (GNNInference undefined).  Guard on
    # GNN_AVAILABLE, mirroring the CNN line below.
    # NOTE(review): assumes RetrainingService accepts gnn_inference=None,
    # as the cnn_inference argument already implies — confirm.
    _retrain_service = RetrainingService(
        bert_classifier=_bert,
        gnn_inference=_gnn or (GNNInference() if GNN_AVAILABLE else None),
        cnn_inference=_cnn or (CNNInference() if CNN_AVAILABLE else None),
    )
    logger.info("β Retraining service initialized")
    logger.info("=== PhishGuard AI ready ===")

    yield

    logger.info("=== PhishGuard AI shutting down ===")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ββ FastAPI App βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 272 |
+
app = FastAPI(
    title="PhishGuard AI Backend",
    version="3.0",
    description="4-tier ML phishing detection with feedback-driven retraining",
    lifespan=lifespan,
)

# Wildcard CORS so the browser extension can call the API from any origin.
# NOTE(review): consider restricting allow_origins to the extension origin
# in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ββ POST /analyze β Full 4-tier pipeline ββββββββββββββββββββββββββββββ
|
| 288 |
+
@app.post("/analyze")
|
| 289 |
+
async def analyze_endpoint(req: AnalyzeRequest) -> dict:
|
| 290 |
+
"""
|
| 291 |
+
Analyze a URL through the 4-tier phishing detection pipeline.
|
| 292 |
+
|
| 293 |
+
Tier 1: Whitelist β SAFE
|
| 294 |
+
Tier 2: Heuristic β BLOCK if >= 80
|
| 295 |
+
Tier 3: BERT+GNN ensemble β BLOCK/SAFE/escalate
|
| 296 |
+
Tier 4: CNN visual + brand hash β BLOCK/SAFE
|
| 297 |
+
"""
|
| 298 |
+
url = req.url
|
| 299 |
+
details: dict = {}
|
| 300 |
+
|
| 301 |
+
# ββ TIER 1: Whitelist ββββββββββββββββββββββββββββββββββββββββ
|
| 302 |
+
root = get_root_domain(url)
|
| 303 |
+
if root in WHITELIST:
|
| 304 |
+
return {
|
| 305 |
+
"url": url,
|
| 306 |
+
"is_phishing": False,
|
| 307 |
+
"confidence": 0.0,
|
| 308 |
+
"method": "whitelist",
|
| 309 |
+
"status": "safe",
|
| 310 |
+
"tier": 1,
|
| 311 |
+
"heuristic_score": 0,
|
| 312 |
+
"signals": [],
|
| 313 |
+
"details": {"whitelisted_domain": root},
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
# ββ Cache check ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
+
cached = _url_cache.get(url)
|
| 318 |
+
if cached is not None:
|
| 319 |
+
return cached
|
| 320 |
+
|
| 321 |
+
# ββ TIER 2: Heuristic scoring ββββββββββββββββββββββββββββββββ
|
| 322 |
+
h_result: HeuristicResult = _scorer.score(url)
|
| 323 |
+
|
| 324 |
+
# Use the higher of server-side and browser-side heuristic scores
|
| 325 |
+
h_score = max(h_result.score, int(req.heuristic_score))
|
| 326 |
+
details["heuristic"] = {
|
| 327 |
+
"score": h_result.score,
|
| 328 |
+
"raw_score": h_result.raw_score,
|
| 329 |
+
"signals": h_result.signals,
|
| 330 |
+
"browser_score": int(req.heuristic_score),
|
| 331 |
+
"combined_score": h_score,
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
if h_score >= 80:
|
| 335 |
+
result = {
|
| 336 |
+
"url": url,
|
| 337 |
+
"is_phishing": True,
|
| 338 |
+
"confidence": h_score / 100.0,
|
| 339 |
+
"method": "heuristic",
|
| 340 |
+
"status": "blocked",
|
| 341 |
+
"tier": 2,
|
| 342 |
+
"heuristic_score": h_score,
|
| 343 |
+
"signals": h_result.signals,
|
| 344 |
+
"details": details,
|
| 345 |
+
}
|
| 346 |
+
_url_cache.set(url, result)
|
| 347 |
+
logger.info(f"Tier 2 BLOCK | url={url[:60]} | score={h_score}")
|
| 348 |
+
return result
|
| 349 |
+
|
| 350 |
+
# ββ TIER 3: BERT + GNN Ensemble ββββββββββββββββββββββββββββββ
|
| 351 |
+
if _tier3 is not None:
|
| 352 |
+
try:
|
| 353 |
+
p3 = await _tier3.predict(
|
| 354 |
+
url=url,
|
| 355 |
+
title=req.page_title,
|
| 356 |
+
snippet=req.page_snippet,
|
| 357 |
+
h_score=h_score,
|
| 358 |
+
)
|
| 359 |
+
details["tier3_score"] = p3
|
| 360 |
+
except Exception as e:
|
| 361 |
+
logger.error(f"Tier 3 error: {e}")
|
| 362 |
+
p3 = h_score / 100.0 # fallback to heuristic
|
| 363 |
+
details["tier3_error"] = str(e)
|
| 364 |
+
else:
|
| 365 |
+
# Tier 3 unavailable β use BERT alone + heuristic
|
| 366 |
+
if _bert is not None:
|
| 367 |
+
loop = asyncio.get_event_loop()
|
| 368 |
+
try:
|
| 369 |
+
p_bert = await loop.run_in_executor(
|
| 370 |
+
None, _bert.predict, url, req.page_title, req.page_snippet,
|
| 371 |
+
)
|
| 372 |
+
except Exception:
|
| 373 |
+
p_bert = 0.5
|
| 374 |
+
h_norm = h_score / 100.0
|
| 375 |
+
p3 = 0.60 * p_bert + 0.40 * h_norm
|
| 376 |
+
else:
|
| 377 |
+
p3 = h_score / 100.0
|
| 378 |
+
details["tier3_score"] = p3
|
| 379 |
+
details["tier3_note"] = "ensemble_unavailable"
|
| 380 |
+
|
| 381 |
+
# Tier 3 decision
|
| 382 |
+
decision = Tier3Ensemble.decide(p3)
|
| 383 |
+
|
| 384 |
+
if decision == "block":
|
| 385 |
+
result = {
|
| 386 |
+
"url": url,
|
| 387 |
+
"is_phishing": True,
|
| 388 |
+
"confidence": round(p3, 4),
|
| 389 |
+
"method": "bert_gnn_ensemble",
|
| 390 |
+
"status": "blocked",
|
| 391 |
+
"tier": 3,
|
| 392 |
+
"heuristic_score": h_score,
|
| 393 |
+
"signals": h_result.signals,
|
| 394 |
+
"details": details,
|
| 395 |
+
}
|
| 396 |
+
_url_cache.set(url, result)
|
| 397 |
+
logger.info(f"Tier 3 BLOCK | url={url[:60]} | P3={p3:.4f}")
|
| 398 |
+
return result
|
| 399 |
+
|
| 400 |
+
if decision == "safe":
|
| 401 |
+
result = {
|
| 402 |
+
"url": url,
|
| 403 |
+
"is_phishing": False,
|
| 404 |
+
"confidence": round(p3, 4),
|
| 405 |
+
"method": "bert_gnn_ensemble",
|
| 406 |
+
"status": "safe",
|
| 407 |
+
"tier": 3,
|
| 408 |
+
"heuristic_score": h_score,
|
| 409 |
+
"signals": h_result.signals,
|
| 410 |
+
"details": details,
|
| 411 |
+
}
|
| 412 |
+
_url_cache.set(url, result)
|
| 413 |
+
logger.info(f"Tier 3 SAFE | url={url[:60]} | P3={p3:.4f}")
|
| 414 |
+
return result
|
| 415 |
+
|
| 416 |
+
# ββ TIER 4: CNN Visual + Brand Hash (borderline 0.40 β€ P3 < 0.85)
|
| 417 |
+
if _cnn is not None and _cnn.is_loaded:
|
| 418 |
+
try:
|
| 419 |
+
# Capture screenshot
|
| 420 |
+
screenshot_bytes = await _capture_screenshot_for_tier4(url)
|
| 421 |
+
|
| 422 |
+
if screenshot_bytes:
|
| 423 |
+
# CNN prediction
|
| 424 |
+
p_cnn = _cnn.predict(screenshot_bytes)
|
| 425 |
+
details["cnn_prob"] = round(p_cnn, 4)
|
| 426 |
+
|
| 427 |
+
# Brand hash check
|
| 428 |
+
brand_boost = 0.0
|
| 429 |
+
if _brand is not None:
|
| 430 |
+
is_impersonation, brand_name, brand_conf = _brand.detect(
|
| 431 |
+
screenshot_bytes, url
|
| 432 |
+
)
|
| 433 |
+
details["brand"] = {
|
| 434 |
+
"impersonation_detected": is_impersonation,
|
| 435 |
+
"brand": brand_name,
|
| 436 |
+
"confidence": round(brand_conf, 3),
|
| 437 |
+
}
|
| 438 |
+
if is_impersonation:
|
| 439 |
+
brand_boost = 0.25
|
| 440 |
+
|
| 441 |
+
# P_final = 0.55Β·P3 + 0.30Β·P_cnn + brand_boost
|
| 442 |
+
p_final = min((0.55 * p3) + (0.30 * p_cnn) + brand_boost, 1.0)
|
| 443 |
+
details["tier4_score"] = round(p_final, 4)
|
| 444 |
+
|
| 445 |
+
is_phishing = p_final >= 0.65
|
| 446 |
+
result = {
|
| 447 |
+
"url": url,
|
| 448 |
+
"is_phishing": is_phishing,
|
| 449 |
+
"confidence": round(p_final, 4),
|
| 450 |
+
"method": "full_ensemble_bert_gnn_cnn",
|
| 451 |
+
"status": "blocked" if is_phishing else "safe",
|
| 452 |
+
"tier": 4,
|
| 453 |
+
"heuristic_score": h_score,
|
| 454 |
+
"signals": h_result.signals,
|
| 455 |
+
"details": details,
|
| 456 |
+
}
|
| 457 |
+
_url_cache.set(url, result)
|
| 458 |
+
logger.info(f"Tier 4 {'BLOCK' if is_phishing else 'SAFE'} | url={url[:60]} | P_final={p_final:.4f}")
|
| 459 |
+
return result
|
| 460 |
+
|
| 461 |
+
except Exception as e:
|
| 462 |
+
logger.error(f"Tier 4 error: {e}")
|
| 463 |
+
details["tier4_error"] = str(e)
|
| 464 |
+
|
| 465 |
+
# Tier 4 unavailable/failed β use Tier 3 score with conservative threshold
|
| 466 |
+
is_phishing = p3 >= 0.65
|
| 467 |
+
result = {
|
| 468 |
+
"url": url,
|
| 469 |
+
"is_phishing": is_phishing,
|
| 470 |
+
"confidence": round(p3, 4),
|
| 471 |
+
"method": "bert_gnn_ensemble",
|
| 472 |
+
"status": "blocked" if is_phishing else "safe",
|
| 473 |
+
"tier": 3,
|
| 474 |
+
"heuristic_score": h_score,
|
| 475 |
+
"signals": h_result.signals,
|
| 476 |
+
"details": details,
|
| 477 |
+
}
|
| 478 |
+
_url_cache.set(url, result)
|
| 479 |
+
logger.info(f"Tier 4 fallback β Tier 3 | url={url[:60]} | P3={p3:.4f}")
|
| 480 |
+
return result
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
async def _capture_screenshot_for_tier4(url: str) -> Optional[bytes]:
    """Capture a PNG screenshot of *url* for Tier 4 CNN analysis.

    Returns the PNG bytes, or None when Playwright is unavailable, the
    page cannot be reached within the 10s navigation timeout, or the
    capture fails for any other reason (best-effort by design).
    """
    try:
        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            # FIX: close the browser even when goto()/screenshot() raises;
            # previously the close was skipped on error and cleanup relied
            # solely on playwright's context-manager teardown.
            try:
                page = await browser.new_page(
                    viewport={"width": 1280, "height": 800},
                    user_agent=(
                        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                        "AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36"
                    ),
                )

                # Block heavy resources (fonts/audio/video) to speed capture.
                await page.route(
                    "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,wav,mp3}",
                    lambda route: route.abort(),
                )

                await page.goto(url, wait_until="domcontentloaded", timeout=10000)
                return await page.screenshot(type="png")
            finally:
                await browser.close()

    except Exception as e:
        logger.warning(f"Tier 4 screenshot failed: {e}")
        return None
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# ββ POST /analyze/email βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 515 |
+
@app.post("/analyze/email")
|
| 516 |
+
async def analyze_email_endpoint(req: EmailRequest) -> dict:
|
| 517 |
+
"""BERT-only path for email body text analysis."""
|
| 518 |
+
# Sender whitelist check
|
| 519 |
+
sender_domain = req.sender.split("@")[-1].lower() if "@" in req.sender else ""
|
| 520 |
+
if sender_domain in WHITELIST:
|
| 521 |
+
return {
|
| 522 |
+
"status": "safe",
|
| 523 |
+
"analysis": {
|
| 524 |
+
"isPhishing": False,
|
| 525 |
+
"probability": 0.0,
|
| 526 |
+
"reason": "Trusted sender domain",
|
| 527 |
+
},
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
# Analyze embedded URLs
|
| 531 |
+
MAX_URLS = 3
|
| 532 |
+
urls_to_check = req.urls[:MAX_URLS]
|
| 533 |
+
|
| 534 |
+
if not urls_to_check:
|
| 535 |
+
# Text-only analysis
|
| 536 |
+
if _bert:
|
| 537 |
+
combined = f"{req.subject} {req.body}"
|
| 538 |
+
prob = _bert.predict(combined, req.subject, req.body)
|
| 539 |
+
is_phishing = prob > 0.6
|
| 540 |
+
return {
|
| 541 |
+
"status": "blocked" if is_phishing else "safe",
|
| 542 |
+
"analysis": {
|
| 543 |
+
"isPhishing": is_phishing,
|
| 544 |
+
"probability": prob,
|
| 545 |
+
"reason": "BERT text analysis (no URLs)",
|
| 546 |
+
},
|
| 547 |
+
}
|
| 548 |
+
return {
|
| 549 |
+
"status": "safe",
|
| 550 |
+
"analysis": {
|
| 551 |
+
"isPhishing": False,
|
| 552 |
+
"probability": 0.1,
|
| 553 |
+
"reason": "No URLs and no ML model available",
|
| 554 |
+
},
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
# Analyze URLs through the main pipeline
|
| 558 |
+
tasks = [
|
| 559 |
+
analyze_endpoint(AnalyzeRequest(url=u, page_title=req.subject))
|
| 560 |
+
for u in urls_to_check
|
| 561 |
+
]
|
| 562 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 563 |
+
|
| 564 |
+
max_prob = 0.0
|
| 565 |
+
phishing_detected = False
|
| 566 |
+
flagged_urls = []
|
| 567 |
+
|
| 568 |
+
for idx, r in enumerate(results):
|
| 569 |
+
if isinstance(r, Exception):
|
| 570 |
+
continue
|
| 571 |
+
prob = r.get("confidence", 0.0)
|
| 572 |
+
max_prob = max(max_prob, prob)
|
| 573 |
+
if r.get("is_phishing"):
|
| 574 |
+
phishing_detected = True
|
| 575 |
+
flagged_urls.append(r.get("url", urls_to_check[idx]))
|
| 576 |
+
|
| 577 |
+
return {
|
| 578 |
+
"status": "blocked" if phishing_detected else "safe",
|
| 579 |
+
"analysis": {
|
| 580 |
+
"isPhishing": phishing_detected,
|
| 581 |
+
"probability": max_prob,
|
| 582 |
+
"flagged_urls": flagged_urls,
|
| 583 |
+
"reason": "URL analysis via ML ensemble",
|
| 584 |
+
},
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
# ββ POST /retrain β Incremental retraining ββββββββββββββββββββββββββββ
|
| 589 |
+
@app.post("/retrain")
|
| 590 |
+
async def retrain_endpoint(req: RetrainRequest) -> dict:
|
| 591 |
+
"""
|
| 592 |
+
Receive labeled feedback and incrementally update all models.
|
| 593 |
+
Uses asyncio.Lock() to prevent concurrent retraining jobs.
|
| 594 |
+
Timeout: 600s max.
|
| 595 |
+
"""
|
| 596 |
+
if _retrain_service is None:
|
| 597 |
+
return {"status": "error", "message": "Retraining service not initialized"}
|
| 598 |
+
|
| 599 |
+
# Prevent concurrent retraining
|
| 600 |
+
if _retrain_lock.locked():
|
| 601 |
+
return {
|
| 602 |
+
"status": "skipped",
|
| 603 |
+
"message": "Retraining already in progress",
|
| 604 |
+
"models_updated": [],
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
async with _retrain_lock:
|
| 608 |
+
# Convert Pydantic models to FeedbackRecord dataclasses
|
| 609 |
+
records = [
|
| 610 |
+
FeedbackRecord(
|
| 611 |
+
url=s.url,
|
| 612 |
+
verdict=s.verdict,
|
| 613 |
+
confidence=s.confidence,
|
| 614 |
+
tier_used=s.tier_used,
|
| 615 |
+
heuristic_score=s.heuristic_score,
|
| 616 |
+
signals=s.signals,
|
| 617 |
+
user_feedback=s.user_feedback,
|
| 618 |
+
timestamp=s.timestamp,
|
| 619 |
+
feedback_ts=s.feedback_ts,
|
| 620 |
+
url_hash=s.url_hash,
|
| 621 |
+
session_id=s.session_id,
|
| 622 |
+
)
|
| 623 |
+
for s in req.samples
|
| 624 |
+
]
|
| 625 |
+
|
| 626 |
+
try:
|
| 627 |
+
result = await asyncio.wait_for(
|
| 628 |
+
_retrain_service.retrain(records),
|
| 629 |
+
timeout=600,
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
# Clear URL cache after retraining (stale results)
|
| 633 |
+
if result.status == "success":
|
| 634 |
+
_url_cache.clear()
|
| 635 |
+
|
| 636 |
+
return {
|
| 637 |
+
"status": result.status,
|
| 638 |
+
"models_updated": result.models_updated,
|
| 639 |
+
"samples_used": result.samples_used,
|
| 640 |
+
"duration_seconds": result.duration_seconds,
|
| 641 |
+
"accuracy_delta": result.accuracy_delta,
|
| 642 |
+
"next_retrain_hint": result.next_retrain_hint,
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
except asyncio.TimeoutError:
|
| 646 |
+
return {
|
| 647 |
+
"status": "error",
|
| 648 |
+
"message": "Retraining timed out (600s limit)",
|
| 649 |
+
}
|
| 650 |
+
except Exception as e:
|
| 651 |
+
logger.error(f"Retrain endpoint error: {e}")
|
| 652 |
+
return {
|
| 653 |
+
"status": "error",
|
| 654 |
+
"message": str(e),
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
# ββ GET /model_version ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 659 |
+
@app.get("/model_version")
|
| 660 |
+
async def model_version_endpoint() -> dict:
|
| 661 |
+
"""Return current model version info for extension polling."""
|
| 662 |
+
if _retrain_service:
|
| 663 |
+
return _retrain_service.get_version_info()
|
| 664 |
+
return {"version": 0, "updated_at": None, "accuracy": {}}
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
# ββ GET /health βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 668 |
+
@app.get("/health")
|
| 669 |
+
async def health_endpoint() -> dict:
|
| 670 |
+
"""Liveness probe with per-tier readiness and model statuses."""
|
| 671 |
+
return {
|
| 672 |
+
"status": "ok",
|
| 673 |
+
"version": "3.0",
|
| 674 |
+
"tier1": True,
|
| 675 |
+
"tier2": _scorer is not None,
|
| 676 |
+
"tier3": _tier3 is not None,
|
| 677 |
+
"tier4": _cnn is not None and _cnn.is_loaded if _cnn else False,
|
| 678 |
+
"retraining_in_progress": _retrain_lock.locked(),
|
| 679 |
+
"model_version": _retrain_service.model_version if _retrain_service else 0,
|
| 680 |
+
"modules": {
|
| 681 |
+
"heuristic": _scorer is not None,
|
| 682 |
+
"bert": _bert is not None and _bert.is_loaded,
|
| 683 |
+
"bert_lazy": _bert is not None and not _bert.is_loaded,
|
| 684 |
+
"gnn": _gnn is not None and _gnn.is_loaded if _gnn else False,
|
| 685 |
+
"cnn": _cnn is not None and _cnn.is_loaded if _cnn else False,
|
| 686 |
+
"brand_hash": _brand is not None,
|
| 687 |
+
},
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
# ββ Legacy feedback endpoint (backward compat) βββββββββββββββββββββββ
|
| 692 |
+
@app.post("/feedback")
|
| 693 |
+
async def legacy_feedback_endpoint(req: dict) -> dict:
|
| 694 |
+
"""Legacy feedback endpoint for backward compatibility."""
|
| 695 |
+
return {"status": "success", "message": "Use POST /retrain for feedback-driven retraining"}
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
# ββ Run directly ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 699 |
+
# uvicorn main:app --reload --port 8000
|
manifest.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"manifest_version": 3,
|
| 3 |
+
"name": "PhishGuard AI",
|
| 4 |
+
"version": "3.0",
|
| 5 |
+
"description": "Adaptive ML-based phishing detection with feedback-driven retraining β BERT + GNN + CNN ensemble",
|
| 6 |
+
"permissions": [
|
| 7 |
+
"tabs",
|
| 8 |
+
"storage",
|
| 9 |
+
"webNavigation",
|
| 10 |
+
"alarms",
|
| 11 |
+
"notifications",
|
| 12 |
+
"activeTab",
|
| 13 |
+
"scripting"
|
| 14 |
+
],
|
| 15 |
+
"host_permissions": [
|
| 16 |
+
"<all_urls>"
|
| 17 |
+
],
|
| 18 |
+
"background": {
|
| 19 |
+
"service_worker": "background.js"
|
| 20 |
+
},
|
| 21 |
+
"content_scripts": [
|
| 22 |
+
{
|
| 23 |
+
"matches": ["<all_urls>"],
|
| 24 |
+
"js": ["content.js"],
|
| 25 |
+
"run_at": "document_idle"
|
| 26 |
+
}
|
| 27 |
+
],
|
| 28 |
+
"action": {
|
| 29 |
+
"default_popup": "popup.html",
|
| 30 |
+
"default_title": "PhishGuard AI"
|
| 31 |
+
},
|
| 32 |
+
"icons": {
|
| 33 |
+
"16": "icons/icon16.png",
|
| 34 |
+
"48": "icons/icon48.png",
|
| 35 |
+
"128": "icons/icon128.png"
|
| 36 |
+
}
|
| 37 |
+
}
|
popup.html
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>PhishGuard AI</title>
|
| 7 |
+
<style>
|
| 8 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
|
| 9 |
+
|
| 10 |
+
* { margin: 0; padding: 0; box-sizing: border-box; }
|
| 11 |
+
|
| 12 |
+
:root {
|
| 13 |
+
--bg-primary: #0F0F14;
|
| 14 |
+
--bg-secondary: #1A1A24;
|
| 15 |
+
--bg-card: #22222E;
|
| 16 |
+
--bg-hover: #2A2A38;
|
| 17 |
+
--text-primary: #EAEAF0;
|
| 18 |
+
--text-secondary: #8888A0;
|
| 19 |
+
--text-muted: #5A5A72;
|
| 20 |
+
--accent: #534AB7;
|
| 21 |
+
--accent-glow: rgba(83, 74, 183, 0.35);
|
| 22 |
+
--safe: #22C55E;
|
| 23 |
+
--safe-glow: rgba(34, 197, 94, 0.25);
|
| 24 |
+
--danger: #EF4444;
|
| 25 |
+
--danger-glow: rgba(239, 68, 68, 0.25);
|
| 26 |
+
--warning: #F59E0B;
|
| 27 |
+
--warning-glow: rgba(245, 158, 11, 0.25);
|
| 28 |
+
--border: rgba(255,255,255,0.06);
|
| 29 |
+
--radius: 12px;
|
| 30 |
+
--radius-sm: 8px;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
body {
|
| 34 |
+
width: 380px;
|
| 35 |
+
min-height: 480px;
|
| 36 |
+
max-height: 640px;
|
| 37 |
+
font-family: 'Inter', -apple-system, sans-serif;
|
| 38 |
+
background: var(--bg-primary);
|
| 39 |
+
color: var(--text-primary);
|
| 40 |
+
overflow-y: auto;
|
| 41 |
+
scrollbar-width: thin;
|
| 42 |
+
scrollbar-color: var(--bg-hover) transparent;
|
| 43 |
+
}
|
| 44 |
+
body::-webkit-scrollbar { width: 4px; }
|
| 45 |
+
body::-webkit-scrollbar-thumb { background: var(--bg-hover); border-radius: 4px; }
|
| 46 |
+
|
| 47 |
+
/* ββ Header ββββββββββββββββββββββββββββββββββββββββββββ */
|
| 48 |
+
.header {
|
| 49 |
+
display: flex; align-items: center; gap: 10px;
|
| 50 |
+
padding: 14px 20px 10px;
|
| 51 |
+
border-bottom: 1px solid var(--border);
|
| 52 |
+
}
|
| 53 |
+
.header-logo {
|
| 54 |
+
width: 28px; height: 28px;
|
| 55 |
+
background: linear-gradient(135deg, var(--accent), #7C6BDB);
|
| 56 |
+
border-radius: var(--radius-sm);
|
| 57 |
+
display: flex; align-items: center; justify-content: center;
|
| 58 |
+
font-size: 14px;
|
| 59 |
+
}
|
| 60 |
+
.header h1 { font-size: 15px; font-weight: 700; letter-spacing: -0.3px; }
|
| 61 |
+
.header h1 span { color: var(--accent); }
|
| 62 |
+
.header-badge {
|
| 63 |
+
margin-left: auto;
|
| 64 |
+
font-size: 10px; padding: 3px 8px;
|
| 65 |
+
background: var(--bg-card); border: 1px solid var(--border);
|
| 66 |
+
border-radius: 20px; color: var(--text-secondary); font-weight: 500;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
/* ββ URL Bar ββββββββββββββββββββββββββββββββββββββββββββ */
|
| 70 |
+
.url-bar {
|
| 71 |
+
padding: 8px 20px;
|
| 72 |
+
background: var(--bg-secondary);
|
| 73 |
+
border-bottom: 1px solid var(--border);
|
| 74 |
+
}
|
| 75 |
+
.url-text {
|
| 76 |
+
font-size: 11px; color: var(--text-muted);
|
| 77 |
+
overflow: hidden; text-overflow: ellipsis; white-space: nowrap;
|
| 78 |
+
font-family: 'SF Mono', 'Fira Code', monospace;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
/* ββ Loading ββββββββββββββββββββββββββββββββββββββββββββ */
|
| 82 |
+
.loading-container {
|
| 83 |
+
display: flex; flex-direction: column; align-items: center;
|
| 84 |
+
justify-content: center; padding: 40px 20px; gap: 14px;
|
| 85 |
+
}
|
| 86 |
+
.spinner {
|
| 87 |
+
width: 40px; height: 40px;
|
| 88 |
+
border: 3px solid var(--bg-hover);
|
| 89 |
+
border-top-color: var(--accent);
|
| 90 |
+
border-radius: 50%;
|
| 91 |
+
animation: spin 0.8s linear infinite;
|
| 92 |
+
}
|
| 93 |
+
@keyframes spin { to { transform: rotate(360deg); } }
|
| 94 |
+
.loading-text {
|
| 95 |
+
font-size: 13px; color: var(--text-secondary);
|
| 96 |
+
animation: pulse 1.5s ease-in-out infinite;
|
| 97 |
+
}
|
| 98 |
+
@keyframes pulse { 0%,100% { opacity: 1; } 50% { opacity: 0.5; } }
|
| 99 |
+
|
| 100 |
+
/* ββ Result Panel ββββββββββββββββββββββββββββββββββββββ */
|
| 101 |
+
.result-panel { padding: 16px 20px; }
|
| 102 |
+
|
| 103 |
+
.result-hero {
|
| 104 |
+
display: flex; align-items: center; gap: 16px;
|
| 105 |
+
margin-bottom: 16px;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
.score-ring-wrap {
|
| 109 |
+
position: relative; width: 80px; height: 80px; flex-shrink: 0;
|
| 110 |
+
}
|
| 111 |
+
.score-ring-bg, .score-ring-fg {
|
| 112 |
+
fill: none; stroke-width: 6;
|
| 113 |
+
}
|
| 114 |
+
.score-ring-bg { stroke: var(--bg-hover); }
|
| 115 |
+
.score-ring-fg {
|
| 116 |
+
stroke-linecap: round;
|
| 117 |
+
transform: rotate(-90deg); transform-origin: center;
|
| 118 |
+
transition: stroke-dashoffset 1s ease, stroke 0.5s;
|
| 119 |
+
stroke-dasharray: 213; stroke-dashoffset: 213;
|
| 120 |
+
}
|
| 121 |
+
.score-label {
|
| 122 |
+
position: absolute; inset: 0;
|
| 123 |
+
display: flex; flex-direction: column;
|
| 124 |
+
align-items: center; justify-content: center;
|
| 125 |
+
}
|
| 126 |
+
.score-pct { font-size: 20px; font-weight: 700; line-height: 1; }
|
| 127 |
+
.score-sub {
|
| 128 |
+
font-size: 9px; color: var(--text-muted);
|
| 129 |
+
margin-top: 2px; text-transform: uppercase; letter-spacing: 0.5px;
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
.shield-icon {
|
| 133 |
+
font-size: 28px;
|
| 134 |
+
animation: shieldPop 0.6s cubic-bezier(0.34, 1.56, 0.64, 1);
|
| 135 |
+
}
|
| 136 |
+
@keyframes shieldPop {
|
| 137 |
+
0% { transform: scale(0.3) rotate(-15deg); opacity: 0; }
|
| 138 |
+
60% { transform: scale(1.15) rotate(3deg); }
|
| 139 |
+
100% { transform: scale(1) rotate(0); opacity: 1; }
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.result-verdict { flex: 1; }
|
| 143 |
+
.verdict-label { font-size: 16px; font-weight: 700; line-height: 1.2; }
|
| 144 |
+
.verdict-detail { font-size: 11px; color: var(--text-secondary); margin-top: 3px; }
|
| 145 |
+
|
| 146 |
+
.status-safe { color: var(--safe); }
|
| 147 |
+
.status-danger { color: var(--danger); }
|
| 148 |
+
.status-warn { color: var(--warning); }
|
| 149 |
+
|
| 150 |
+
/* ββ Tier Rows βββββββββββββββββββββββββββββββββββββββββ */
|
| 151 |
+
.tier-section { margin-top: 4px; }
|
| 152 |
+
.tier-row {
|
| 153 |
+
background: var(--bg-card);
|
| 154 |
+
border: 1px solid var(--border);
|
| 155 |
+
border-radius: var(--radius-sm);
|
| 156 |
+
margin-bottom: 5px; overflow: hidden;
|
| 157 |
+
transition: border-color 0.2s;
|
| 158 |
+
}
|
| 159 |
+
.tier-row:hover { border-color: rgba(255,255,255,0.1); }
|
| 160 |
+
.tier-header {
|
| 161 |
+
display: flex; align-items: center;
|
| 162 |
+
padding: 8px 12px; cursor: pointer;
|
| 163 |
+
user-select: none; gap: 8px;
|
| 164 |
+
}
|
| 165 |
+
.tier-dot { width: 7px; height: 7px; border-radius: 50%; flex-shrink: 0; }
|
| 166 |
+
.tier-name { font-size: 11px; font-weight: 600; flex: 1; }
|
| 167 |
+
.tier-score {
|
| 168 |
+
font-size: 11px; font-weight: 600;
|
| 169 |
+
font-family: 'SF Mono', 'Fira Code', monospace;
|
| 170 |
+
}
|
| 171 |
+
.tier-chevron {
|
| 172 |
+
font-size: 9px; color: var(--text-muted);
|
| 173 |
+
transition: transform 0.2s;
|
| 174 |
+
}
|
| 175 |
+
.tier-row.open .tier-chevron { transform: rotate(180deg); }
|
| 176 |
+
.tier-body {
|
| 177 |
+
max-height: 0; overflow: hidden;
|
| 178 |
+
transition: max-height 0.3s ease; padding: 0 12px;
|
| 179 |
+
}
|
| 180 |
+
.tier-row.open .tier-body { max-height: 200px; padding: 4px 12px 10px; }
|
| 181 |
+
.tier-detail { font-size: 10px; color: var(--text-secondary); line-height: 1.6; }
|
| 182 |
+
.flag-badge {
|
| 183 |
+
display: inline-block; padding: 1px 6px;
|
| 184 |
+
background: rgba(239,68,68,0.12); color: var(--danger);
|
| 185 |
+
border-radius: 4px; font-size: 10px; margin: 2px 2px 2px 0;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/* ββ Feedback Section ββββββββββββββββββββββββββββββββββ */
|
| 189 |
+
.feedback-section {
|
| 190 |
+
padding: 0 20px 12px; margin-top: 8px;
|
| 191 |
+
}
|
| 192 |
+
.feedback-prompt {
|
| 193 |
+
font-size: 12px; color: var(--text-secondary);
|
| 194 |
+
margin-bottom: 8px; text-align: center;
|
| 195 |
+
}
|
| 196 |
+
.feedback-buttons { display: flex; gap: 8px; }
|
| 197 |
+
.fb-btn {
|
| 198 |
+
flex: 1; padding: 8px 0;
|
| 199 |
+
border: 1px solid var(--border); border-radius: var(--radius-sm);
|
| 200 |
+
background: var(--bg-card); color: var(--text-primary);
|
| 201 |
+
font-size: 13px; font-weight: 600; cursor: pointer;
|
| 202 |
+
transition: all 0.2s; font-family: inherit;
|
| 203 |
+
}
|
| 204 |
+
.fb-btn:hover { background: var(--bg-hover); }
|
| 205 |
+
.fb-btn-correct:hover { border-color: var(--safe); background: rgba(34,197,94,0.08); }
|
| 206 |
+
.fb-btn-wrong:hover { border-color: var(--danger); background: rgba(239,68,68,0.08); }
|
| 207 |
+
.fb-btn.selected {
|
| 208 |
+
opacity: 1 !important;
|
| 209 |
+
}
|
| 210 |
+
.fb-btn.dimmed {
|
| 211 |
+
opacity: 0.3; pointer-events: none;
|
| 212 |
+
}
|
| 213 |
+
.fb-btn-correct.selected {
|
| 214 |
+
border-color: var(--safe); background: rgba(34,197,94,0.15);
|
| 215 |
+
color: var(--safe);
|
| 216 |
+
}
|
| 217 |
+
.fb-btn-wrong.selected {
|
| 218 |
+
border-color: var(--danger); background: rgba(239,68,68,0.15);
|
| 219 |
+
color: var(--danger);
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
.thank-you {
|
| 223 |
+
display: none; text-align: center; padding: 10px;
|
| 224 |
+
font-size: 12px; color: var(--safe);
|
| 225 |
+
animation: slideDown 0.3s ease;
|
| 226 |
+
}
|
| 227 |
+
.thank-you.show { display: block; }
|
| 228 |
+
@keyframes slideDown {
|
| 229 |
+
from { opacity: 0; transform: translateY(-6px); }
|
| 230 |
+
to { opacity: 1; transform: translateY(0); }
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/* ββ Retraining Status βββββββββββββββββββββββββββββββββ */
|
| 234 |
+
.retrain-section {
|
| 235 |
+
padding: 8px 20px 12px;
|
| 236 |
+
border-top: 1px solid var(--border);
|
| 237 |
+
margin-top: 4px;
|
| 238 |
+
}
|
| 239 |
+
.retrain-row {
|
| 240 |
+
display: flex; align-items: center; gap: 6px;
|
| 241 |
+
font-size: 11px; color: var(--text-muted);
|
| 242 |
+
margin-bottom: 4px;
|
| 243 |
+
}
|
| 244 |
+
.retrain-row .icon { font-size: 12px; }
|
| 245 |
+
.retrain-progress {
|
| 246 |
+
height: 3px; background: var(--bg-hover);
|
| 247 |
+
border-radius: 2px; margin: 6px 0 4px;
|
| 248 |
+
overflow: hidden;
|
| 249 |
+
}
|
| 250 |
+
.retrain-progress-bar {
|
| 251 |
+
height: 100%; background: linear-gradient(90deg, var(--accent), #7C6BDB);
|
| 252 |
+
border-radius: 2px; transition: width 0.5s ease;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/* ββ Session Stats βββββββββββββββββββββββββββββββββββββ */
|
| 256 |
+
.stats-row {
|
| 257 |
+
display: flex; justify-content: space-between;
|
| 258 |
+
padding: 6px 20px;
|
| 259 |
+
border-top: 1px solid var(--border);
|
| 260 |
+
font-size: 11px; color: var(--text-muted);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
/* ββ Blocked Overlay βββββββββββββββββββββββββββββββββββ */
|
| 264 |
+
.blocked-overlay {
|
| 265 |
+
display: none; padding: 28px 24px; text-align: center;
|
| 266 |
+
}
|
| 267 |
+
.blocked-overlay.show {
|
| 268 |
+
display: flex; flex-direction: column;
|
| 269 |
+
align-items: center; gap: 10px;
|
| 270 |
+
}
|
| 271 |
+
.blocked-shield { font-size: 48px; animation: shieldPop 0.6s ease; }
|
| 272 |
+
.blocked-title { font-size: 18px; font-weight: 700; color: var(--danger); }
|
| 273 |
+
.blocked-url {
|
| 274 |
+
font-size: 11px; color: var(--text-muted);
|
| 275 |
+
word-break: break-all; max-width: 300px;
|
| 276 |
+
}
|
| 277 |
+
.blocked-method { font-size: 12px; color: var(--text-secondary); }
|
| 278 |
+
.proceed-btn {
|
| 279 |
+
margin-top: 8px; padding: 8px 20px;
|
| 280 |
+
background: transparent; border: 1px solid rgba(239,68,68,0.3);
|
| 281 |
+
border-radius: var(--radius-sm); color: var(--text-secondary);
|
| 282 |
+
font-size: 12px; cursor: pointer; font-family: inherit;
|
| 283 |
+
transition: all 0.2s;
|
| 284 |
+
}
|
| 285 |
+
.proceed-btn:hover {
|
| 286 |
+
background: rgba(239,68,68,0.08);
|
| 287 |
+
border-color: var(--danger); color: var(--danger);
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
.offline-banner {
|
| 291 |
+
display: none; padding: 8px 16px;
|
| 292 |
+
background: rgba(245, 158, 11, 0.08);
|
| 293 |
+
border: 1px solid rgba(245, 158, 11, 0.2);
|
| 294 |
+
border-radius: var(--radius-sm);
|
| 295 |
+
margin: 8px 20px 0; font-size: 11px;
|
| 296 |
+
color: var(--warning); text-align: center;
|
| 297 |
+
}
|
| 298 |
+
.offline-banner.show { display: block; }
|
| 299 |
+
</style>
|
| 300 |
+
</head>
|
| 301 |
+
<body>
|
| 302 |
+
|
| 303 |
+
<!-- Header -->
|
| 304 |
+
<div class="header">
|
| 305 |
+
<div class="header-logo">π‘οΈ</div>
|
| 306 |
+
<h1>Phish<span>Guard</span> AI</h1>
|
| 307 |
+
<span class="header-badge" id="versionBadge">v3.0</span>
|
| 308 |
+
</div>
|
| 309 |
+
|
| 310 |
+
<!-- URL Bar -->
|
| 311 |
+
<div class="url-bar">
|
| 312 |
+
<div class="url-text" id="currentUrl">Analyzing...</div>
|
| 313 |
+
</div>
|
| 314 |
+
|
| 315 |
+
<!-- Server offline banner -->
|
| 316 |
+
<div class="offline-banner" id="offlineBanner">
|
| 317 |
+
β οΈ Server offline β local heuristic only
|
| 318 |
+
</div>
|
| 319 |
+
|
| 320 |
+
<!-- Loading State -->
|
| 321 |
+
<div class="loading-container" id="loadingState">
|
| 322 |
+
<div class="spinner"></div>
|
| 323 |
+
<div class="loading-text">Analyzing with AI ensemble...</div>
|
| 324 |
+
</div>
|
| 325 |
+
|
| 326 |
+
<!-- Result Panel -->
|
| 327 |
+
<div class="result-panel" id="resultPanel" style="display:none;">
|
| 328 |
+
<div class="result-hero">
|
| 329 |
+
<div class="score-ring-wrap">
|
| 330 |
+
<svg width="80" height="80" viewBox="0 0 80 80">
|
| 331 |
+
<circle class="score-ring-bg" cx="40" cy="40" r="34" />
|
| 332 |
+
<circle class="score-ring-fg" id="scoreRing" cx="40" cy="40" r="34" />
|
| 333 |
+
</svg>
|
| 334 |
+
<div class="score-label">
|
| 335 |
+
<div class="score-pct" id="scorePct">0%</div>
|
| 336 |
+
<div class="score-sub" id="scoreSub">RISK</div>
|
| 337 |
+
</div>
|
| 338 |
+
</div>
|
| 339 |
+
<div style="display:flex; flex-direction:column; align-items:center; gap:4px">
|
| 340 |
+
<div class="shield-icon" id="shieldIcon">π‘οΈ</div>
|
| 341 |
+
</div>
|
| 342 |
+
<div class="result-verdict">
|
| 343 |
+
<div class="verdict-label" id="verdictLabel">Analyzing</div>
|
| 344 |
+
<div class="verdict-detail" id="verdictDetail">Please wait...</div>
|
| 345 |
+
</div>
|
| 346 |
+
</div>
|
| 347 |
+
|
| 348 |
+
<!-- Tier Rows -->
|
| 349 |
+
<div class="tier-section" id="tierSection">
|
| 350 |
+
<div class="tier-row" data-tier="1">
|
| 351 |
+
<div class="tier-header" onclick="toggleTier(this)">
|
| 352 |
+
<div class="tier-dot" id="t1Dot"></div>
|
| 353 |
+
<div class="tier-name">Tier 1 Β· Whitelist</div>
|
| 354 |
+
<div class="tier-score" id="t1Score">β</div>
|
| 355 |
+
<div class="tier-chevron">βΌ</div>
|
| 356 |
+
</div>
|
| 357 |
+
<div class="tier-body"><div class="tier-detail" id="t1Detail">O(1) domain lookup</div></div>
|
| 358 |
+
</div>
|
| 359 |
+
<div class="tier-row" data-tier="2">
|
| 360 |
+
<div class="tier-header" onclick="toggleTier(this)">
|
| 361 |
+
<div class="tier-dot" id="t2Dot"></div>
|
| 362 |
+
<div class="tier-name">Tier 2 Β· Heuristics</div>
|
| 363 |
+
<div class="tier-score" id="t2Score">β</div>
|
| 364 |
+
<div class="tier-chevron">βΌ</div>
|
| 365 |
+
</div>
|
| 366 |
+
<div class="tier-body"><div class="tier-detail" id="t2Detail">15 regex/math signals</div></div>
|
| 367 |
+
</div>
|
| 368 |
+
<div class="tier-row" data-tier="3">
|
| 369 |
+
<div class="tier-header" onclick="toggleTier(this)">
|
| 370 |
+
<div class="tier-dot" id="t3Dot"></div>
|
| 371 |
+
<div class="tier-name">Tier 3 Β· BERT + GNN</div>
|
| 372 |
+
<div class="tier-score" id="t3Score">β</div>
|
| 373 |
+
<div class="tier-chevron">βΌ</div>
|
| 374 |
+
</div>
|
| 375 |
+
<div class="tier-body"><div class="tier-detail" id="t3Detail">Parallel NLP + graph analysis</div></div>
|
| 376 |
+
</div>
|
| 377 |
+
<div class="tier-row" data-tier="4">
|
| 378 |
+
<div class="tier-header" onclick="toggleTier(this)">
|
| 379 |
+
<div class="tier-dot" id="t4Dot"></div>
|
| 380 |
+
<div class="tier-name">Tier 4 Β· CNN Visual</div>
|
| 381 |
+
<div class="tier-score" id="t4Score">β</div>
|
| 382 |
+
<div class="tier-chevron">βΌ</div>
|
| 383 |
+
</div>
|
| 384 |
+
<div class="tier-body"><div class="tier-detail" id="t4Detail">Screenshot + brand detection</div></div>
|
| 385 |
+
</div>
|
| 386 |
+
</div>
|
| 387 |
+
</div>
|
| 388 |
+
|
| 389 |
+
<!-- Feedback Section -->
|
| 390 |
+
<div class="feedback-section" id="feedbackSection" style="display:none;">
|
| 391 |
+
<div class="feedback-prompt">Was this correct?</div>
|
| 392 |
+
<div class="feedback-buttons">
|
| 393 |
+
<button class="fb-btn fb-btn-correct" id="btnCorrect">π Correct</button>
|
| 394 |
+
<button class="fb-btn fb-btn-wrong" id="btnWrong">👎 Incorrect</button>
|
| 395 |
+
</div>
|
| 396 |
+
<div class="thank-you" id="thankYou">β Thanks! Helps us improve π―</div>
|
| 397 |
+
</div>
|
| 398 |
+
|
| 399 |
+
<!-- Retraining Status -->
|
| 400 |
+
<div class="retrain-section" id="retrainSection" style="display:none;">
|
| 401 |
+
<div class="retrain-row">
|
| 402 |
+
<span class="icon">π</span>
|
| 403 |
+
<span id="retrainStatus">Next retrain: calculating...</span>
|
| 404 |
+
</div>
|
| 405 |
+
<div class="retrain-progress">
|
| 406 |
+
<div class="retrain-progress-bar" id="retrainProgressBar" style="width: 0%"></div>
|
| 407 |
+
</div>
|
| 408 |
+
<div class="retrain-row">
|
| 409 |
+
<span class="icon">π</span>
|
| 410 |
+
<span id="retrainLast">No retraining yet</span>
|
| 411 |
+
</div>
|
| 412 |
+
</div>
|
| 413 |
+
|
| 414 |
+
<!-- Session Stats -->
|
| 415 |
+
<div class="stats-row" id="statsRow" style="display:none;">
|
| 416 |
+
<span id="statScanned">π 0 scanned</span>
|
| 417 |
+
<span id="statFeedback">π¬ 0 feedback</span>
|
| 418 |
+
<span id="statVersion">π·οΈ v0</span>
|
| 419 |
+
</div>
|
| 420 |
+
|
| 421 |
+
<!-- Blocked Page Overlay -->
|
| 422 |
+
<div class="blocked-overlay" id="blockedOverlay">
|
| 423 |
+
<div class="blocked-shield">π¨</div>
|
| 424 |
+
<div class="blocked-title">Phishing Detected!</div>
|
| 425 |
+
<div class="blocked-url" id="blockedUrl"></div>
|
| 426 |
+
<div class="blocked-method" id="blockedMethod"></div>
|
| 427 |
+
<button class="proceed-btn" id="proceedBtn">Proceed Anyway (Unsafe)</button>
|
| 428 |
+
</div>
|
| 429 |
+
|
| 430 |
+
<script src="popup.js"></script>
|
| 431 |
+
</body>
|
| 432 |
+
</html>
|
popup.js
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// ============================================================
|
| 2 |
+
// PhishGuard AI - popup.js
|
| 3 |
+
// Popup logic: displays verdict, feedback buttons, retraining
|
| 4 |
+
// status, and session stats.
|
| 5 |
+
// ============================================================
|
| 6 |
+
|
| 7 |
+
(function() {
|
| 8 |
+
"use strict";
|
| 9 |
+
|
| 10 |
+
// ββ DOM Elements ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 11 |
+
const $id = id => document.getElementById(id);
|
| 12 |
+
const loadingState = $id("loadingState");
|
| 13 |
+
const resultPanel = $id("resultPanel");
|
| 14 |
+
const feedbackSection = $id("feedbackSection");
|
| 15 |
+
const retrainSection = $id("retrainSection");
|
| 16 |
+
const statsRow = $id("statsRow");
|
| 17 |
+
const blockedOverlay = $id("blockedOverlay");
|
| 18 |
+
const offlineBanner = $id("offlineBanner");
|
| 19 |
+
|
| 20 |
+
let currentResult = null;
|
| 21 |
+
let currentUrlHash = null;
|
| 22 |
+
let feedbackGiven = false;
|
| 23 |
+
let feedbackTimeout = null;
|
| 24 |
+
let countdownInterval = null;
|
| 25 |
+
|
| 26 |
+
// ββ Init ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
// Entry point: decide between the blocked-page view and a normal scan
// view for the active tab, then populate the UI.
async function init() {
  // Background worker redirects here with ?blocked=1 for blocked pages.
  const qs = new URLSearchParams(window.location.search);
  if (qs.get("blocked") === "1") {
    showBlockedPage(qs);
    return;
  }

  // Resolve the currently focused tab.
  const [tab] = await chrome.tabs.query({ active: true, currentWindow: true });
  const url = tab?.url;
  if (!url || !url.startsWith("http")) {
    // Internal pages (chrome://, about:, …) are always treated as safe.
    showResult({
      status: "safe", tier: 0, method: "internal",
      confidence: 0, url: url || "N/A"
    });
    return;
  }

  $id("currentUrl").textContent = url;

  // Fast path: per-tab cached verdict from the background worker.
  chrome.runtime.sendMessage(
    { type: "get_tab_result", tabId: tab.id },
    response => {
      if (response?.result) {
        showResult(response.result);
        return;
      }
      // Slow path: last result persisted in chrome.storage.
      chrome.storage.local.get("lastResult", data => {
        if (data.lastResult && data.lastResult.url === url) {
          showResult(data.lastResult);
        } else {
          loadingState.style.display = "flex";
        }
      });
    }
  );

  // Retrain/session stats load independently of the verdict.
  loadStatus();
}
|
| 69 |
+
|
| 70 |
+
// ββ Show Result βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
// Render a scan verdict: score ring, shield icon, verdict text,
// per-tier status rows, and any triggered heuristic signal badges.
function showResult(result) {
  currentResult = result;

  // Reveal the result UI, hide the spinner.
  loadingState.style.display = "none";
  resultPanel.style.display = "block";
  feedbackSection.style.display = "block";
  retrainSection.style.display = "block";
  statsRow.style.display = "flex";

  const isBlocked = result.status === "blocked" || result.is_phishing;
  const isWarn = !isBlocked && (result.confidence || 0) >= 0.4;
  const confidence = Math.round((result.confidence || 0) * 100);
  const tier = result.tier || 0;
  const statusClass = isBlocked ? "status-danger" : (isWarn ? "status-warn" : "status-safe");

  // Score ring: SVG circle with r=34, so circumference = 2πr ≈ 213.6.
  const ring = $id("scoreRing");
  const circumference = 2 * Math.PI * 34;
  const offset = circumference - (confidence / 100) * circumference;
  ring.style.strokeDasharray = circumference;

  // Short delay lets the CSS transition animate from the initial offset.
  setTimeout(() => {
    ring.style.strokeDashoffset = offset;
    ring.style.stroke = isBlocked ? "var(--danger)"
                      : isWarn    ? "var(--warning)"
                      :             "var(--safe)";
  }, 100);

  $id("scorePct").textContent = confidence + "%";
  $id("scorePct").className = `score-pct ${statusClass}`;
  $id("scoreSub").textContent = isBlocked ? "THREAT" : "RISK";

  // Shield emoji mirrors the verdict severity.
  $id("shieldIcon").textContent = isBlocked ? "🚨" : isWarn ? "⚠️" : "✅";

  $id("verdictLabel").textContent = isBlocked ? "PHISHING DETECTED"
                                  : isWarn    ? "SUSPICIOUS"
                                  :             "SAFE";
  $id("verdictLabel").className = `verdict-label ${statusClass}`;

  // Human-readable names for backend detection-method identifiers.
  const methodNames = {
    "whitelist": "Whitelist (Tier 1)",
    "heuristic": "Heuristic Engine (Tier 2)",
    "heuristic-fallback": "Heuristic Fallback",
    "bert_gnn_ensemble": "BERT + GNN Ensemble (Tier 3)",
    "full_ensemble_bert_gnn_cnn": "Full ML Ensemble (Tier 4)",
    "ensemble_with_visual": "Ensemble + Visual (Tier 4)",
    "user-override": "User Override",
  };
  const methodText = methodNames[result.method] || result.method || "Unknown";
  $id("verdictDetail").textContent = `${methodText} · Confidence: ${confidence}%`;

  // Tier dots/scores: green once a tier ran; red on the tier that blocked.
  updateTierRow(1, tier >= 1 ? "checked" : "pending", tier === 1 ? "SAFE ✓" : "Miss →");
  updateTierRow(2, tier >= 2 ? (isBlocked && tier === 2 ? "blocked" : "checked") : "pending",
                result.heuristic_score != null ? `${result.heuristic_score}/100` : "—");
  updateTierRow(3, tier >= 3 ? (isBlocked && tier === 3 ? "blocked" : "checked") : "pending",
                result.details?.tier3_score != null ? (result.details.tier3_score * 100).toFixed(0) + "%" : "—");
  updateTierRow(4, tier >= 4 ? (isBlocked && tier === 4 ? "blocked" : "checked") : "pending",
                result.details?.tier4_score != null ? (result.details.tier4_score * 100).toFixed(0) + "%" : "—");

  // Tier 2 detail pane: list the heuristic signals that fired.
  if (result.signals && result.signals.length > 0) {
    const badges = result.signals.map(s => `<span class="flag-badge">${s}</span>`).join(" ");
    $id("t2Detail").innerHTML = `Signals triggered:<br>${badges}`;
  }

  // Hash the URL so feedback is keyed without storing the raw URL.
  computeUrlHash(result.url);

  // Restore button state if feedback was already given for this URL.
  checkExistingFeedback();
}
|
| 141 |
+
|
| 142 |
+
// Paint one tier row: colour its status dot and set its score text.
// status is one of "checked" | "blocked" | "pending" (default).
function updateTierRow(tier, status, scoreText) {
  const palette = {
    checked: "var(--safe)",
    blocked: "var(--danger)",
    pending: "var(--text-muted)",
  };
  $id(`t${tier}Dot`).style.background = palette[status] || palette.pending;
  $id(`t${tier}Score`).textContent = scoreText;
}
|
| 154 |
+
|
| 155 |
+
// ββ Blocked Page ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 156 |
+
// Render the "blocked" interstitial view (popup opened with ?blocked=1).
// Reads url/score/method from the query string supplied by background.js.
function showBlockedPage(params) {
  loadingState.style.display = "none";
  blockedOverlay.classList.add("show");

  const url = decodeURIComponent(params.get("url") || "");
  const score = params.get("score") || "0";
  const method = decodeURIComponent(params.get("method") || "");

  $id("blockedUrl").textContent = url;
  $id("blockedMethod").textContent = `Detection: ${method} · Risk: ${score}%`;
  $id("currentUrl").textContent = url;

  currentResult = { url, status: "blocked", confidence: parseInt(score) / 100, method };
  computeUrlHash(url);

  // Blocked pages still collect feedback and show retrain/session stats.
  feedbackSection.style.display = "block";
  retrainSection.style.display = "block";
  statsRow.style.display = "flex";
  loadStatus();

  // "Proceed anyway": whitelist the URL, then navigate the tab to it.
  $id("proceedBtn").onclick = () => {
    chrome.runtime.sendMessage({ type: "whitelist_url", url }, () => {
      chrome.tabs.update({ url });
    });
  };
}
|
| 183 |
+
|
| 184 |
+
// ββ Feedback ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
+
// SHA-256 the URL and store the lowercase hex digest in currentUrlHash.
// Feedback records are keyed by this hash, never by the raw URL.
async function computeUrlHash(url) {
  if (!url) return;
  const digest = await crypto.subtle.digest(
    "SHA-256", new TextEncoder().encode(url));
  const hex = [];
  for (const byte of new Uint8Array(digest)) {
    hex.push(byte.toString(16).padStart(2, "0"));
  }
  currentUrlHash = hex.join("");
}
|
| 192 |
+
|
| 193 |
+
// Send a "correct"/"incorrect" verdict for the current URL to the
// background worker, then lock the buttons. The lock auto-releases after
// 5 minutes so the user can change their answer.
function submitFeedback(feedback) {
  if (feedbackGiven || !currentUrlHash) return;
  feedbackGiven = true; // optimistic lock against double-submits

  chrome.runtime.sendMessage({
    type: "submit_feedback",
    url_hash: currentUrlHash,
    feedback: feedback,
  }, response => {
    if (!response?.success) {
      // FIX: the original left feedbackGiven=true on a failed submission,
      // permanently disabling feedback for this popup session. Release the
      // lock so the user can retry.
      feedbackGiven = false;
      return;
    }

    // Highlight the chosen button, dim the other.
    const correctBtn = $id("btnCorrect");
    const wrongBtn = $id("btnWrong");
    const [picked, other] = feedback === "correct"
      ? [correctBtn, wrongBtn]
      : [wrongBtn, correctBtn];
    picked.classList.add("selected");
    other.classList.add("dimmed");

    $id("thankYou").classList.add("show");

    // Allow changing the answer after 5 minutes.
    // FIX: clear any previously armed timer before re-arming, so a stale
    // timeout cannot reset the buttons early.
    if (feedbackTimeout) clearTimeout(feedbackTimeout);
    feedbackTimeout = setTimeout(() => {
      feedbackGiven = false;
      correctBtn.classList.remove("selected", "dimmed");
      wrongBtn.classList.remove("selected", "dimmed");
    }, 5 * 60 * 1000);
  });
}
|
| 226 |
+
|
| 227 |
+
// Restore feedback-button state when feedback was already recorded for
// this URL (found in the locally queued phishguard_feedback_queue).
function checkExistingFeedback() {
  chrome.storage.local.get("phishguard_feedback_queue", data => {
    const queue = data.phishguard_feedback_queue || [];
    const record = queue.find(r => r.url_hash === currentUrlHash);
    if (!record?.user_feedback) return;

    feedbackGiven = true;
    const correctBtn = $id("btnCorrect");
    const wrongBtn = $id("btnWrong");
    if (record.user_feedback === "correct") {
      correctBtn.classList.add("selected");
      wrongBtn.classList.add("dimmed");
    } else {
      wrongBtn.classList.add("selected");
      correctBtn.classList.add("dimmed");
    }
  });
}
|
| 246 |
+
|
| 247 |
+
// ββ Retraining Status βββββββββββββββββββββββββββββββββββββββββ
|
| 248 |
+
// Pull scan/retrain statistics from the background worker and populate
// the stats row, version badge, and retraining-progress widgets.
function loadStatus() {
  chrome.runtime.sendMessage({ type: "get_status" }, status => {
    if (!status) return;

    // Session stats + version badges.
    $id("statScanned").textContent = `🔍 ${status.scan_count} scanned`;
    $id("statFeedback").textContent = `💬 ${status.labeled_count || 0} labeled`;
    $id("statVersion").textContent = `🏷️ v${status.model_version}`;
    $id("versionBadge").textContent = `v${status.model_version || "3.0"}`;

    // Progress bar toward the 50-URL retrain trigger.
    const urlsRemaining = status.next_retrain_urls_remaining || 50;
    const progress = Math.round(((50 - urlsRemaining) / 50) * 100);
    $id("retrainProgressBar").style.width = `${progress}%`;

    // Time remaining until the scheduled retrain.
    const timeMs = status.next_retrain_time_remaining_ms || 0;
    const hours = Math.floor(timeMs / 3600000);
    const mins = Math.floor((timeMs % 3600000) / 60000);

    // Labeled-feedback shortfall takes priority over the countdown text.
    const labeledNeeded = status.min_labeled_needed || 0;
    $id("retrainStatus").textContent = labeledNeeded > 0
      ? `Need ${labeledNeeded} more feedback to retrain`
      : `Next retrain: ${urlsRemaining} URLs or ${hours}h ${mins}m`;

    if (status.last_retrain_ts) {
      const ago = timeSince(new Date(status.last_retrain_ts));
      $id("retrainLast").textContent = `Last retrain: ${ago} ago`;
    }

    // Live per-second countdown of the remaining time.
    startCountdown(timeMs);
  });
}
|
| 287 |
+
|
| 288 |
+
// Tick the "Next retrain: N URLs or XhYmZs" line once per second.
// FIX: the original armed the interval even when initialMs <= 0; after one
// second the expiry branch unconditionally overwrote whatever status text
// was showing (e.g. "Need N more feedback") with "Retrain pending...".
// Skip the timer entirely when no countdown time remains.
function startCountdown(initialMs) {
  if (countdownInterval) clearInterval(countdownInterval);
  if (!initialMs || initialMs <= 0) return;
  let remaining = initialMs;

  countdownInterval = setInterval(() => {
    remaining -= 1000;
    if (remaining <= 0) {
      clearInterval(countdownInterval);
      $id("retrainStatus").textContent = "Retrain pending...";
      return;
    }
    const h = Math.floor(remaining / 3600000);
    const m = Math.floor((remaining % 3600000) / 60000);
    const s = Math.floor((remaining % 60000) / 1000);

    // Only rewrite the time half of the "N URLs or ..." message; leave
    // other status texts (e.g. the feedback-shortfall message) untouched.
    const el = $id("retrainStatus");
    if (el.textContent.includes("URLs or")) {
      const parts = el.textContent.split(" or ");
      el.textContent = `${parts[0]} or ${h}h ${m}m ${s}s`;
    }
  }, 1000);
}
|
| 311 |
+
|
| 312 |
+
// Compact "time ago" formatter: e.g. "45s", "12m", "3h", "2d".
function timeSince(date) {
  const seconds = Math.floor((Date.now() - date.getTime()) / 1000);
  // Largest unit first; falls through to seconds for anything under a minute.
  const units = [[86400, "d"], [3600, "h"], [60, "m"]];
  for (const [size, suffix] of units) {
    if (seconds >= size) return `${Math.floor(seconds / size)}${suffix}`;
  }
  return `${seconds}s`;
}
|
| 319 |
+
|
| 320 |
+
// ββ Tier Row Toggle βββββββββββββββββββββββββββββββββββββββββββ
|
| 321 |
+
// Expand/collapse a tier row. Exposed on window because popup.html
// invokes it from an inline onclick attribute.
window.toggleTier = function(header) {
  header.parentElement.classList.toggle("open");
};
|
| 325 |
+
|
| 326 |
+
// ββ Event Listeners βββββββββββββββββββββββββββββββββββββββββββ
|
| 327 |
+
// ── Event Listeners ───────────────────────────────────────────
// Wire both feedback buttons to submitFeedback with their verdict value.
for (const [btnId, verdict] of [["btnCorrect", "correct"], ["btnWrong", "incorrect"]]) {
  $id(btnId).addEventListener("click", () => submitFeedback(verdict));
}

// ── Start ─────────────────────────────────────────────────────
init();
|
| 332 |
+
})();
|
render.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# render.yaml β Render.com deployment config (architecture doc 5.2)
|
| 3 |
+
#
|
| 4 |
+
# PLAYWRIGHT ON RENDER:
|
| 5 |
+
# To enable Tier 4 visual analysis (Playwright + Chromium), you need:
|
| 6 |
+
#
|
| 7 |
+
# 1. Add ENABLE_VISUAL_TIER=1 env var below
|
| 8 |
+
# 2. Switch from python:3.10-slim to a full image in Dockerfile
|
| 9 |
+
# OR add Chromium system deps to the Dockerfile:
|
| 10 |
+
# apt-get install -y libnss3 libatk1.0-0 libatk-bridge2.0-0 \
|
| 11 |
+
# libcups2 libxkbcommon0 libgbm1 libpango-1.0-0 \
|
| 12 |
+
# libcairo2 libasound2 libxdamage1 libxrandr2 libxfixes3
|
| 13 |
+
# 3. Add to Dockerfile after pip install:
|
| 14 |
+
# RUN pip install playwright && playwright install chromium
|
| 15 |
+
#
|
| 16 |
+
# NOTE: Playwright + Chromium adds ~400MB to the Docker image.
|
| 17 |
+
# On the free tier (512MB RAM), this may cause OOM.
|
| 18 |
+
# Only enable if you have a paid plan with >= 1GB RAM.
|
| 19 |
+
# ============================================================
|
| 20 |
+
|
| 21 |
+
services:
|
| 22 |
+
- type: web
|
| 23 |
+
name: phishguard-api
|
| 24 |
+
runtime: docker
|
| 25 |
+
dockerfilePath: ./Dockerfile
|
| 26 |
+
plan: free
|
| 27 |
+
healthCheckPath: /health
|
| 28 |
+
autoDeploy: true
|
| 29 |
+
envVars:
|
| 30 |
+
- key: PORT
|
| 31 |
+
value: "8000"
|
| 32 |
+
# Uncomment below to enable Tier 4 visual analysis (needs Playwright in Dockerfile)
|
| 33 |
+
# - key: ENABLE_VISUAL_TIER
|
| 34 |
+
# value: "1"
|
requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.111.0
|
| 2 |
+
uvicorn[standard]==0.29.0
|
| 3 |
+
transformers==4.40.0
|
| 4 |
+
torch==2.2.2
|
| 5 |
+
torch-geometric==2.5.2
|
| 6 |
+
torchvision==0.17.2
|
| 7 |
+
playwright==1.44.0
|
| 8 |
+
pillow==10.3.0
|
| 9 |
+
scikit-learn==1.4.2
|
| 10 |
+
pandas==2.2.2
|
| 11 |
+
numpy==1.26.4
|
| 12 |
+
httpx==0.27.0
|
| 13 |
+
imagehash==4.3.1
|
| 14 |
+
requests==2.31.0
|
| 15 |
+
aiohttp==3.9.5
|
| 16 |
+
aiofiles==23.2.1
|
| 17 |
+
python-multipart==0.0.9
|
| 18 |
+
apscheduler==3.10.4
|
retraining_service.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - retraining_service.py
|
| 3 |
+
# Incremental retraining service for all 3 ML models.
|
| 4 |
+
#
|
| 5 |
+
# Receives labeled feedback samples from the Chrome extension.
|
| 6 |
+
# Runs parallel incremental updates for BERT, GNN, and CNN.
|
| 7 |
+
# Tracks model version and accuracy deltas.
|
| 8 |
+
# Supports hot-reload of all models without server restart.
|
| 9 |
+
# ============================================================
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
import time
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Dict, List, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger("phishguard.retrain")

# Persistent location for the model-version metadata JSON written by
# RetrainingService._save_version() and read by get_version_info().
DATA_DIR = Path(__file__).parent / "data"
MODEL_VERSION_PATH = DATA_DIR / "model_version.json"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class FeedbackRecord:
    """A single feedback record from the Chrome extension."""

    url: str  # page URL the verdict was issued for
    verdict: str  # "phishing" or "safe"
    confidence: float = 0.0  # model confidence reported with the verdict
    tier_used: int = 0  # analysis tier that produced the verdict (4 routes to the CNN)
    heuristic_score: int = 0  # heuristic score at analysis time — presumably Tier 1/2; confirm with caller
    signals: List[str] = field(default_factory=list)  # names of triggered heuristic signals
    user_feedback: Optional[str] = None  # "correct" or "incorrect"
    timestamp: str = ""  # when the verdict was issued (extension-supplied)
    feedback_ts: Optional[str] = None  # when the user submitted feedback, if they did
    url_hash: str = ""  # hashed URL identifier supplied by the extension
    session_id: str = ""  # extension session that produced this record
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class RetrainResult:
    """Result from a retraining run."""

    status: str  # "success", "skipped", or "error"
    models_updated: List[str] = field(default_factory=list)  # subset of ["bert", "gnn", "cnn"]
    samples_used: int = 0  # labeled samples that survived validation
    duration_seconds: float = 0.0  # wall-clock duration of the run
    accuracy_delta: Dict[str, Optional[float]] = field(default_factory=dict)  # per-model delta; None if not updated
    next_retrain_hint: Dict = field(default_factory=dict)  # guidance for the client's next /retrain call
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class RetrainingService:
    """
    Orchestrates incremental retraining for all 3 ML models
    (BERT, GNN, CNN) from user-labeled feedback.

    Called by the POST /retrain endpoint.  Tracks a monotonically
    increasing model version persisted at MODEL_VERSION_PATH and
    hot-reloads updated models in-memory without a server restart.
    """

    def __init__(
        self,
        bert_classifier,
        gnn_inference,
        cnn_inference,
    ) -> None:
        # Model wrappers are injected by the caller; this service only
        # requires that each exposes incremental_update() (and reload()
        # for GNN/CNN, load_local() for BERT).
        self._bert = bert_classifier
        self._gnn = gnn_inference
        self._cnn = cnn_inference
        self._model_version = self._load_version()

    def _load_version(self) -> int:
        """Load the current model version from disk; 0 if absent or corrupt."""
        MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
        if MODEL_VERSION_PATH.exists():
            try:
                data = json.loads(MODEL_VERSION_PATH.read_text())
                return data.get("version", 0)
            except Exception:
                # Corrupt/unreadable version file: start over from 0.
                pass
        return 0

    def _save_version(self, accuracy_delta: Dict[str, Optional[float]]) -> None:
        """Bump the model version and persist it with the latest accuracy deltas."""
        self._model_version += 1
        data = {
            "version": self._model_version,
            "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "accuracy": accuracy_delta,
        }
        MODEL_VERSION_PATH.write_text(json.dumps(data, indent=2))

    @property
    def model_version(self) -> int:
        """Current (in-memory) model version counter."""
        return self._model_version

    def get_version_info(self) -> dict:
        """Get current model version info for GET /model_version."""
        if MODEL_VERSION_PATH.exists():
            try:
                return json.loads(MODEL_VERSION_PATH.read_text())
            except Exception:
                pass
        # Fallback when the file is missing or unreadable.
        return {
            "version": self._model_version,
            "updated_at": None,
            "accuracy": {},
        }

    async def retrain(
        self,
        samples: List[FeedbackRecord],
    ) -> RetrainResult:
        """
        Perform incremental retraining on all models.

        Steps:
          1. Validate samples (min 10, URL format check)
          2. Derive true labels from user feedback; separate Tier-4 samples
          3. Run BERT + GNN updates in parallel (thread executor)
          4. Run CNN update if Tier 4 samples exist
          5. Record accuracy_delta per model and bump the version
          6. Hot-reload every model that was updated

        Returns a RetrainResult with status ("success"/"skipped"/"error"),
        the models updated, sample count, duration and per-model deltas.
        """
        start_time = time.time()

        # 1. Validate
        valid_samples = self._validate_samples(samples)
        if len(valid_samples) < 10:
            return RetrainResult(
                status="skipped",
                samples_used=len(valid_samples),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": 10 - len(valid_samples),
                },
            )

        # 2. Convert to (url, label) pairs.  The true label flips the
        # model's verdict when the user marked it "incorrect".
        url_label_pairs: List[Tuple[str, int]] = []
        tier4_pairs: List[Tuple[str, int]] = []

        for sample in valid_samples:
            if sample.user_feedback == "correct":
                label = 1 if sample.verdict == "phishing" else 0
            elif sample.user_feedback == "incorrect":
                label = 0 if sample.verdict == "phishing" else 1
            else:
                continue

            url_label_pairs.append((sample.url, label))
            if sample.tier_used == 4:
                # Tier 4 verdicts came from visual analysis: feed the CNN too.
                tier4_pairs.append((sample.url, label))

        if len(url_label_pairs) < 5:
            return RetrainResult(
                status="skipped",
                samples_used=len(url_label_pairs),
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": 5,
                },
            )

        # 3. Run updates
        models_updated: List[str] = []
        accuracy_delta: Dict[str, Optional[float]] = {}

        try:
            # BERT + GNN in parallel on the default thread executor.
            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here since Python 3.10.
            loop = asyncio.get_running_loop()

            bert_task = loop.run_in_executor(
                None,
                self._bert.incremental_update,
                url_label_pairs,
            )

            gnn_task = loop.run_in_executor(
                None,
                self._gnn.incremental_update,
                url_label_pairs,
            )

            bert_delta, gnn_delta = await asyncio.gather(
                bert_task, gnn_task,
                return_exceptions=True,
            )

            # Process BERT result: a failure disables only that model's delta.
            if isinstance(bert_delta, Exception):
                logger.error("BERT update error: %s", bert_delta)
                accuracy_delta["bert"] = None
            elif bert_delta is not None:
                accuracy_delta["bert"] = bert_delta
                models_updated.append("bert")
            else:
                accuracy_delta["bert"] = None

            # Process GNN result
            if isinstance(gnn_delta, Exception):
                logger.error("GNN update error: %s", gnn_delta)
                accuracy_delta["gnn"] = None
            elif gnn_delta is not None:
                accuracy_delta["gnn"] = gnn_delta
                models_updated.append("gnn")
            else:
                accuracy_delta["gnn"] = None

            # 4. CNN update (only if Tier 4 samples exist)
            if tier4_pairs:
                try:
                    cnn_delta = await self._cnn.incremental_update(tier4_pairs)
                    if cnn_delta is not None:
                        accuracy_delta["cnn"] = cnn_delta
                        models_updated.append("cnn")
                    else:
                        accuracy_delta["cnn"] = None
                except Exception as e:
                    logger.error("CNN update error: %s", e)
                    accuracy_delta["cnn"] = None
            else:
                accuracy_delta["cnn"] = None

            # 5. Update version (only when something actually changed)
            if models_updated:
                self._save_version(accuracy_delta)

            # 6. Hot-reload the updated models
            await self._hot_reload(models_updated)

            duration = time.time() - start_time

            return RetrainResult(
                status="success" if models_updated else "skipped",
                models_updated=models_updated,
                samples_used=len(url_label_pairs),
                duration_seconds=round(duration, 2),
                accuracy_delta=accuracy_delta,
                next_retrain_hint={
                    "recommended_trigger": "count",
                    "min_samples_needed": 10,
                },
            )

        except Exception as e:
            logger.error("Retraining failed: %s", e)
            return RetrainResult(
                status="error",
                duration_seconds=round(time.time() - start_time, 2),
                accuracy_delta=accuracy_delta,
            )

    def _validate_samples(self, samples: List[FeedbackRecord]) -> List[FeedbackRecord]:
        """Keep only samples with explicit user feedback and an http(s) URL."""
        valid = []
        for s in samples:
            # Must have user feedback
            if not s.user_feedback:
                continue
            if s.user_feedback not in ("correct", "incorrect"):
                continue
            # Must have a valid URL
            if not s.url or not s.url.startswith(("http://", "https://")):
                continue
            valid.append(s)
        return valid

    async def _hot_reload(self, models: List[str]) -> None:
        """Hot-reload updated models in-memory; failures are logged, not raised."""
        if "bert" in models:
            try:
                bert_weights = Path(__file__).parent / "bert_weights"
                if bert_weights.exists():
                    self._bert.load_local(bert_weights)
                    logger.info("BERT hot-reloaded")
            except Exception as e:
                logger.error("BERT hot-reload failed: %s", e)

        if "gnn" in models:
            try:
                self._gnn.reload()
                logger.info("GNN hot-reloaded")
            except Exception as e:
                logger.error("GNN hot-reload failed: %s", e)

        if "cnn" in models:
            try:
                self._cnn.reload()
                logger.info("CNN hot-reloaded")
            except Exception as e:
                logger.error("CNN hot-reload failed: %s", e)
|
screenshot_collector.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - screenshot_collector.py
|
| 3 |
+
# Batch screenshot capture for CNN training data generation.
|
| 4 |
+
#
|
| 5 |
+
# Uses Playwright async API with 10 concurrent captures.
|
| 6 |
+
# Blocks fonts, media, video for 60-70% speedup.
|
| 7 |
+
# Saves PNG named by URL SHA256 hash.
|
| 8 |
+
# ============================================================
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import asyncio
|
| 13 |
+
import hashlib
|
| 14 |
+
import logging
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=logging.INFO,
|
| 21 |
+
format="%(asctime)s | %(levelname)-7s | %(message)s",
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger("phishguard.screenshot_collector")
|
| 24 |
+
|
| 25 |
+
BACKEND_DIR = Path(__file__).parent
|
| 26 |
+
DATA_DIR = BACKEND_DIR / "data"
|
| 27 |
+
SCREENSHOTS_DIR = DATA_DIR / "screenshots"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def url_to_filename(url: str) -> str:
    """Map a URL to a filesystem-safe PNG name: first 16 hex chars of its SHA-256."""
    digest = hashlib.sha256(url.encode("utf-8")).hexdigest()
    return digest[:16] + ".png"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def capture_single(
    url: str,
    save_dir: Path,
    semaphore: asyncio.Semaphore,
    browser,
) -> bool:
    """
    Capture a single screenshot, bounded by *semaphore*.

    Returns True on success (or when the screenshot already exists on
    disk), False on any navigation/capture failure.
    """
    async with semaphore:
        filename = url_to_filename(url)
        filepath = save_dir / filename

        # Skip if already captured — makes re-runs idempotent.
        if filepath.exists():
            return True

        page = None  # keeps the finally block safe if new_page() itself fails
        try:
            page = await browser.new_page(
                viewport={"width": 1280, "height": 800},
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                    "AppleWebKit/537.36 (KHTML, like Gecko) "
                    "Chrome/120.0.0.0 Safari/537.36"
                ),
            )

            # Block heavy resources (fonts/media) for a large speedup.
            await page.route(
                "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,avi,mp3,wav,flac}",
                lambda route: route.abort(),
            )

            await page.goto(
                url,
                wait_until="domcontentloaded",
                timeout=10000,
            )

            # Brief wait for rendering to settle after DOMContentLoaded.
            await asyncio.sleep(0.5)

            screenshot = await page.screenshot(type="png")
            filepath.write_bytes(screenshot)
            return True

        except Exception as e:
            logger.debug(f"Screenshot failed for {url}: {e}")
            return False

        finally:
            # Always release the page here.  The original closed it on both
            # paths separately: on a new_page() failure the except path
            # referenced an unbound `page`, and a close() error after a
            # successful capture fell into the except and wrongly
            # reported failure even though the PNG was already written.
            if page is not None:
                try:
                    await page.close()
                except Exception:
                    pass
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
async def batch_capture(
    urls: List[str],
    save_dir: Path,
    concurrency: int = 10,
    label: str = "urls",
) -> int:
    """
    Capture screenshots for a batch of URLs concurrently.

    Launches one headless Chromium instance and fans out one task per
    URL; the actual parallelism is bounded by the semaphore enforced
    inside capture_single.  Returns the count of successful captures.
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    # Import lazily so the module stays usable where Playwright is absent.
    try:
        from playwright.async_api import async_playwright
    except ImportError:
        logger.error("Playwright not installed. Run: pip install playwright && playwright install chromium")
        return 0

    semaphore = asyncio.Semaphore(concurrency)
    success_count = 0

    async with async_playwright() as p:
        browser = await p.chromium.launch(headless=True)

        tasks = [
            capture_single(url, save_dir, semaphore, browser)
            for url in urls
        ]

        # return_exceptions=True: one failed page must not cancel the batch.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # NOTE(review): this loop runs only after gather() has finished, so
        # the periodic "processed" log lines fire all at once rather than
        # as live progress.
        for i, result in enumerate(results):
            if result is True:
                success_count += 1
            if (i + 1) % 50 == 0:
                logger.info(f" {label}: {i+1}/{len(urls)} processed ({success_count} captured)")

        await browser.close()

    logger.info(f" {label}: {success_count}/{len(urls)} screenshots captured")
    return success_count
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
async def collect_training_screenshots(
    phish_count: int = 10,
    legit_count: int = 10,
) -> None:
    """
    Collect labeled screenshots for CNN training.

    Downloads up to *phish_count* phishing URLs (PhishTank) and
    *legit_count* legitimate URLs (Tranco), then captures each into
    data/screenshots/{phishing,legitimate}/ via batch_capture.
    """
    # Imported lazily: data_collector is only needed for this script path.
    from data_collector import download_phishtank, download_tranco

    phishing_dir = SCREENSHOTS_DIR / "phishing"
    legitimate_dir = SCREENSHOTS_DIR / "legitimate"

    # Download URL lists
    print("π₯ Loading URL lists...")
    phish_urls = download_phishtank(max_urls=phish_count)[:phish_count]
    legit_urls = download_tranco(n=legit_count)[:legit_count]

    print(f"\nπΈ Capturing phishing screenshots ({len(phish_urls)} URLs)...")
    phish_success = await batch_capture(phish_urls, phishing_dir, label="Phishing")

    print(f"\nπΈ Capturing legitimate screenshots ({len(legit_urls)} URLs)...")
    legit_success = await batch_capture(legit_urls, legitimate_dir, label="Legitimate")

    # NOTE(review): the emoji in these status strings arrived mojibake'd
    # through extraction; reproduced as-is to avoid guessing the originals.
    print(f"\nβ Screenshots collected:")
    print(f"   Phishing:   {phish_success}/{len(phish_urls)}")
    print(f"   Legitimate: {legit_success}/{len(legit_urls)}")
    print(f"   Saved to:   {SCREENSHOTS_DIR}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def main() -> None:
    """CLI entry point: run the async screenshot-collection pipeline."""
    print("=" * 60)
    # Fix: the banner literal contained U+FFFD replacement characters
    # (mojibake from a lost em dash); restored to a readable separator.
    print("PhishGuard AI — Screenshot Collection")
    print("=" * 60)

    asyncio.run(collect_training_screenshots())

    print("=" * 60)


if __name__ == "__main__":
    main()
|
screenshot_hasher.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - cnn/screenshot_hasher.py
|
| 3 |
+
# Perceptual hash-based brand impersonation detector.
|
| 4 |
+
#
|
| 5 |
+
# Compares webpage screenshots against reference hashes of
|
| 6 |
+
# known brand login pages using imagehash.phash.
|
| 7 |
+
#
|
| 8 |
+
# brand_boost = 0.25 if hamming_distance < 10 else 0.0
|
| 9 |
+
# ============================================================
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import io
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Tuple, Optional, Dict, List
|
| 18 |
+
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger("phishguard.cnn.hasher")
|
| 22 |
+
|
| 23 |
+
# ββ Try to use imagehash, fall back to custom implementation βββββββββ
|
| 24 |
+
_imagehash_available = False
|
| 25 |
+
try:
|
| 26 |
+
import imagehash
|
| 27 |
+
_imagehash_available = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
logger.info("imagehash not installed β using built-in phash")
|
| 30 |
+
|
| 31 |
+
HASH_DB_PATH = Path(__file__).parent / "brand_hashes.json"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BrandHashDetector:
    """
    Perceptual hash-based brand impersonation detector.

    Compares screenshots against reference hashes of 10 major brands.
    A page that visually matches a brand (small hamming distance) but
    is not served from that brand's real domain is flagged as
    impersonation.
    """

    BRANDS: List[str] = [
        "paypal", "google", "apple", "microsoft", "amazon",
        "chase", "netflix", "facebook", "instagram", "wellsfargo",
    ]

    # Canonical registrable domain for each tracked brand.
    BRAND_DOMAINS: Dict[str, str] = {
        "paypal": "paypal.com",
        "google": "google.com",
        "apple": "apple.com",
        "microsoft": "microsoft.com",
        "amazon": "amazon.com",
        "chase": "chase.com",
        "netflix": "netflix.com",
        "facebook": "facebook.com",
        "instagram": "instagram.com",
        "wellsfargo": "wellsfargo.com",
    }

    def __init__(self, hash_db_path: Optional[Path] = None) -> None:
        self._hash_db_path = hash_db_path or HASH_DB_PATH
        self._reference_hashes: Dict[str, dict] = {}
        self._load_reference_hashes()

    def _load_reference_hashes(self) -> None:
        """Load reference hashes from the JSON database; empty dict on failure."""
        if self._hash_db_path.exists():
            try:
                with open(self._hash_db_path) as f:
                    self._reference_hashes = json.load(f)
                logger.info(f"Loaded {len(self._reference_hashes)} brand hashes")
            except Exception as e:
                logger.warning(f"Failed to load brand hashes: {e}")
                self._reference_hashes = {}
        else:
            logger.info("No brand hash DB found β brand detection disabled")
            self._reference_hashes = {}

    def compute_hash(self, img_bytes: bytes, hash_size: int = 16) -> Optional[int]:
        """
        Compute a perceptual hash of an image as an int, or None on failure.

        Uses imagehash.phash when available, otherwise the built-in
        mean-threshold fallback (_custom_phash).
        """
        try:
            img = Image.open(io.BytesIO(img_bytes))

            if _imagehash_available:
                h = imagehash.phash(img, hash_size=hash_size)
                return int(str(h), 16)
            else:
                return self._custom_phash(img, hash_size)

        except Exception as e:
            logger.warning(f"Hash computation failed: {e}")
            return None

    def _custom_phash(self, img: Image.Image, hash_size: int = 16) -> int:
        """Fallback perceptual hash (mean-based, no DCT)."""
        img = img.convert("L").resize((hash_size, hash_size), Image.LANCZOS)
        pixels = list(img.getdata())
        avg = sum(pixels) / len(pixels)
        # One bit per pixel: 1 if brighter than the mean.
        bits = "".join("1" if p > avg else "0" for p in pixels)
        return int(bits, 2)

    def hamming_distance(self, h1: int, h2: int) -> int:
        """Count bit differences between two hashes. 0 = identical."""
        return bin(h1 ^ h2).count("1")

    def detect(
        self,
        screenshot_bytes: bytes,
        url: str = "",
        threshold: int = 10,
    ) -> Tuple[bool, str, float]:
        """
        Detect brand impersonation from a screenshot.

        Returns:
            (is_impersonation, brand_name, confidence)
            is_impersonation: True if the page looks like a brand but the
                URL's hostname does not belong to that brand's domain
            brand_name: detected brand name or ""
            confidence: 0.0-1.0 similarity score
        """
        page_hash = self.compute_hash(screenshot_bytes)
        if page_hash is None:
            return False, "", 0.0

        url_lower = url.lower()
        best_match: Optional[str] = None
        best_distance = 999
        best_confidence = 0.0

        for brand, entry in self._reference_hashes.items():
            try:
                stored_hash = int(entry["hash"])
                distance = self.hamming_distance(page_hash, stored_hash)
                # 256 = bit width of the default hash_size=16 phash.
                confidence = max(0.0, 1.0 - distance / 256.0)

                if distance < best_distance:
                    best_distance = distance
                    best_match = brand
                    best_confidence = confidence

            except (ValueError, KeyError):
                # Malformed DB entry: skip it rather than abort detection.
                continue

        if best_match and best_distance <= threshold:
            legit_domain = self.BRAND_DOMAINS.get(best_match, f"{best_match}.com")

            # Fix: compare against the URL's *hostname*, not a raw substring.
            # The old `legit_domain not in url_lower` test let attacker hosts
            # like "paypal.com.evil.com" — or a "?ref=paypal.com" query —
            # pass as the legitimate site.
            from urllib.parse import urlparse
            host = urlparse(url_lower).hostname or ""
            is_legit = host == legit_domain or host.endswith("." + legit_domain)

            if is_legit:
                return False, best_match, best_confidence
            return True, best_match, best_confidence

        return False, "", 0.0

    def register_brand(
        self,
        brand_name: str,
        domain: str,
        screenshot_bytes: bytes,
    ) -> bool:
        """Register a brand's reference screenshot hash and persist the DB."""
        h = self.compute_hash(screenshot_bytes)
        if h is None:
            return False

        self._reference_hashes[brand_name] = {
            "domain": domain,
            "hash": str(h),  # stored as decimal string; parsed back with int()
        }

        # Save to disk
        try:
            with open(self._hash_db_path, "w") as f:
                json.dump(self._reference_hashes, f, indent=2)
            logger.info(f"Registered brand: {brand_name} ({domain})")
            return True
        except Exception as e:
            logger.error(f"Failed to save brand hash: {e}")
            return False
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ββ Legacy compatibility βββββββββββββββββββββββββββββββββββββββββββββ
|
| 184 |
+
_detector = BrandHashDetector()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def check_brand_impersonation(
    screenshot_bytes: bytes,
    url: str,
    similarity_threshold: int = 10,
) -> dict:
    """Legacy wrapper for backward compatibility."""
    hit, brand, score = _detector.detect(
        screenshot_bytes, url, similarity_threshold,
    )

    # Three outcomes: impersonation, a matching legitimate brand site,
    # or no visual match at all.
    if hit:
        return {
            "impersonation_detected": True,
            "impersonated_brand": brand,
            "legitimate_domain": _detector.BRAND_DOMAINS.get(brand, ""),
            "visual_similarity": round(score, 3),
        }
    if brand:
        return {
            "impersonation_detected": False,
            "matched_brand": brand,
            "note": "legitimate site",
        }
    return {
        "impersonation_detected": False,
        "reason": "no_brand_match",
    }
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cls_token": {
|
| 3 |
+
"content": "[CLS]",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"mask_token": {
|
| 10 |
+
"content": "[MASK]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"sep_token": {
|
| 24 |
+
"content": "[SEP]",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"unk_token": {
|
| 31 |
+
"content": "[UNK]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
}
|
| 37 |
+
}
|
test_endpoint.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
import json

from main import analyze_email_endpoint, EmailRequest


async def run_tests():
    """Exercise Tier 1 and Tier 2 paths of the email analysis endpoint."""
    # (banner, request-kwargs) pairs — identical inputs and output text
    # to the original sequential version, just table-driven.
    cases = [
        (
            "--- Testing Tier 1: Whitelist ---",
            dict(
                sender="noreply@github.com",
                subject="Your receipt",
                body="...",
                urls=[],
            ),
        ),
        (
            "\n--- Testing Tier 2: Text Heuristic (No URLs) ---",
            dict(
                sender="admin@unknown-domain.com",
                subject="URGENT: Password Reset Required",
                body="Please reset immediately.",
                urls=[],
            ),
        ),
        (
            "\n--- Testing Tier 2: Async URLs ---",
            dict(
                sender="service@paypal-update.net",
                subject="Important Account Update",
                body="Click the link to verify your account.",
                urls=["http://chase-bank-verify-login.cx/auth", "http://google.com/"],
            ),
        ),
    ]

    for banner, kwargs in cases:
        print(banner)
        result = await analyze_email_endpoint(EmailRequest(**kwargs))
        print(json.dumps(result, indent=2))

if __name__ == "__main__":
    asyncio.run(run_tests())
tier3_bert_gnn.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - tier3_bert_gnn.py
|
| 3 |
+
# Tier 3: BERT + GNN Parallel Ensemble
|
| 4 |
+
#
|
| 5 |
+
# Triggered only when Tier 2 score < 80.
|
| 6 |
+
# BERT and GNN run in PARALLEL via asyncio.gather + run_in_executor.
|
| 7 |
+
#
|
| 8 |
+
# Ensemble formula:
|
| 9 |
+
# P3 = 0.45Β·P_bert + 0.35Β·P_gnn + 0.20Β·(H_score/100)
|
| 10 |
+
#
|
| 11 |
+
# Decision:
|
| 12 |
+
# P3 >= 0.85 β BLOCK
|
| 13 |
+
# P3 < 0.40 β SAFE
|
| 14 |
+
# 0.40 <= P3 < 0.85 β escalate to Tier 4
|
| 15 |
+
# ============================================================
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import asyncio
|
| 20 |
+
import logging
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("phishguard.tier3")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Tier3Ensemble:
|
| 27 |
+
"""
|
| 28 |
+
Tier 3: BERT + GNN parallel ensemble classifier.
|
| 29 |
+
|
| 30 |
+
Runs BERT and GNN inference in parallel using asyncio.gather
|
| 31 |
+
with run_in_executor for non-blocking thread pool execution.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
# Ensemble weights
|
| 35 |
+
W_BERT: float = 0.45
|
| 36 |
+
W_GNN: float = 0.35
|
| 37 |
+
W_HEURISTIC: float = 0.20
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
bert_classifier,
|
| 42 |
+
gnn_inference,
|
| 43 |
+
) -> None:
|
| 44 |
+
self._bert = bert_classifier
|
| 45 |
+
self._gnn = gnn_inference
|
| 46 |
+
|
| 47 |
+
async def predict(
|
| 48 |
+
self,
|
| 49 |
+
url: str,
|
| 50 |
+
title: str = "",
|
| 51 |
+
snippet: str = "",
|
| 52 |
+
h_score: int = 0,
|
| 53 |
+
) -> float:
|
| 54 |
+
"""
|
| 55 |
+
Run BERT + GNN in parallel and compute ensemble score.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
url: The URL to analyze
|
| 59 |
+
title: Page title (optional)
|
| 60 |
+
snippet: Page content snippet (optional)
|
| 61 |
+
h_score: Heuristic score from Tier 2 (0-100, passed through, NOT recomputed)
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
P3 β [0,1] β ensemble phishing probability
|
| 65 |
+
"""
|
| 66 |
+
loop = asyncio.get_event_loop()
|
| 67 |
+
|
| 68 |
+
# Run BERT and GNN in parallel (both are CPU-bound, use thread pool)
|
| 69 |
+
bert_task = self._bert_predict(url, title, snippet, loop)
|
| 70 |
+
gnn_task = self._gnn_predict(url, loop)
|
| 71 |
+
|
| 72 |
+
p_bert, p_gnn = await asyncio.gather(bert_task, gnn_task)
|
| 73 |
+
|
| 74 |
+
# Ensemble: P3 = 0.45Β·P_bert + 0.35Β·P_gnn + 0.20Β·H_norm
|
| 75 |
+
h_norm = h_score / 100.0
|
| 76 |
+
p3 = (self.W_BERT * p_bert) + (self.W_GNN * p_gnn) + (self.W_HEURISTIC * h_norm)
|
| 77 |
+
|
| 78 |
+
logger.info(
|
| 79 |
+
f"Tier3 ensemble | url={url[:60]} | "
|
| 80 |
+
f"P_bert={p_bert:.4f} P_gnn={p_gnn:.4f} H_norm={h_norm:.4f} β P3={p3:.4f}"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return round(min(max(p3, 0.0), 1.0), 4)
|
| 84 |
+
|
| 85 |
+
async def _bert_predict(
|
| 86 |
+
self,
|
| 87 |
+
url: str,
|
| 88 |
+
title: str,
|
| 89 |
+
snippet: str,
|
| 90 |
+
loop: asyncio.AbstractEventLoop,
|
| 91 |
+
) -> float:
|
| 92 |
+
"""
|
| 93 |
+
Run BERT inference in thread pool (non-blocking).
|
| 94 |
+
Returns P_bert β [0,1].
|
| 95 |
+
"""
|
| 96 |
+
try:
|
| 97 |
+
p_bert = await asyncio.wait_for(
|
| 98 |
+
loop.run_in_executor(
|
| 99 |
+
None, # Default thread pool
|
| 100 |
+
self._bert.predict,
|
| 101 |
+
url,
|
| 102 |
+
title,
|
| 103 |
+
snippet,
|
| 104 |
+
),
|
| 105 |
+
timeout=10.0,
|
| 106 |
+
)
|
| 107 |
+
return float(p_bert)
|
| 108 |
+
except asyncio.TimeoutError:
|
| 109 |
+
logger.warning(f"BERT timeout for {url[:50]}")
|
| 110 |
+
return 0.5 # Neutral on timeout
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error(f"BERT predict error: {e}")
|
| 113 |
+
return 0.5
|
| 114 |
+
|
| 115 |
+
async def _gnn_predict(
|
| 116 |
+
self,
|
| 117 |
+
url: str,
|
| 118 |
+
loop: asyncio.AbstractEventLoop,
|
| 119 |
+
) -> float:
|
| 120 |
+
"""
|
| 121 |
+
Run GNN inference in thread pool (non-blocking).
|
| 122 |
+
Returns P_gnn β [0,1].
|
| 123 |
+
"""
|
| 124 |
+
try:
|
| 125 |
+
p_gnn = await asyncio.wait_for(
|
| 126 |
+
loop.run_in_executor(
|
| 127 |
+
None,
|
| 128 |
+
self._gnn.predict,
|
| 129 |
+
url,
|
| 130 |
+
None, # related_urls
|
| 131 |
+
),
|
| 132 |
+
timeout=5.0,
|
| 133 |
+
)
|
| 134 |
+
return float(p_gnn)
|
| 135 |
+
except asyncio.TimeoutError:
|
| 136 |
+
logger.warning(f"GNN timeout for {url[:50]}")
|
| 137 |
+
return 0.5
|
| 138 |
+
except Exception as e:
|
| 139 |
+
logger.error(f"GNN predict error: {e}")
|
| 140 |
+
return 0.5
|
| 141 |
+
|
| 142 |
+
@staticmethod
|
| 143 |
+
def decide(p3: float) -> str:
|
| 144 |
+
"""
|
| 145 |
+
Make decision based on P3 score.
|
| 146 |
+
Returns: 'block', 'safe', or 'escalate'
|
| 147 |
+
"""
|
| 148 |
+
if p3 >= 0.85:
|
| 149 |
+
return "block"
|
| 150 |
+
elif p3 < 0.40:
|
| 151 |
+
return "safe"
|
| 152 |
+
else:
|
| 153 |
+
return "escalate"
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"100": {
|
| 12 |
+
"content": "[UNK]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"101": {
|
| 20 |
+
"content": "[CLS]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"102": {
|
| 28 |
+
"content": "[SEP]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"103": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"clean_up_tokenization_spaces": true,
|
| 45 |
+
"cls_token": "[CLS]",
|
| 46 |
+
"do_lower_case": true,
|
| 47 |
+
"mask_token": "[MASK]",
|
| 48 |
+
"model_max_length": 512,
|
| 49 |
+
"pad_token": "[PAD]",
|
| 50 |
+
"sep_token": "[SEP]",
|
| 51 |
+
"strip_accents": null,
|
| 52 |
+
"tokenize_chinese_chars": true,
|
| 53 |
+
"tokenizer_class": "BertTokenizer",
|
| 54 |
+
"unk_token": "[UNK]"
|
| 55 |
+
}
|
train_cnn.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - cnn/train_cnn.py
|
| 3 |
+
# CNN fine-tuning script for phishing screenshot detection.
|
| 4 |
+
#
|
| 5 |
+
# Loads data/screenshots/ with ImageFolder structure
|
| 6 |
+
# Augmentation: RandomHorizontalFlip, ColorJitter, RandomRotation
|
| 7 |
+
# 15 epochs, AdamW on head only (backbone stays frozen)
|
| 8 |
+
# Saves cnn_weights.pt + cnn_replay_buffer.pt
|
| 9 |
+
# Works with as few as 100 images per class
|
| 10 |
+
# ============================================================
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=logging.INFO,
|
| 21 |
+
format="%(asctime)s | %(levelname)-7s | %(message)s",
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger("phishguard.cnn.train")
|
| 24 |
+
|
| 25 |
+
CNN_DIR = Path(__file__).parent
|
| 26 |
+
BACKEND_DIR = CNN_DIR.parent
|
| 27 |
+
WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt"
|
| 28 |
+
REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt"
|
| 29 |
+
SCREENSHOTS_DIR = BACKEND_DIR / "data" / "screenshots"
|
| 30 |
+
|
| 31 |
+
sys.path.insert(0, str(CNN_DIR))
|
| 32 |
+
sys.path.insert(0, str(BACKEND_DIR))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main() -> None:
|
| 36 |
+
print("=" * 60)
|
| 37 |
+
print("PhishGuard AI β CNN Training")
|
| 38 |
+
print("=" * 60)
|
| 39 |
+
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
from torch.optim import AdamW
|
| 43 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
| 44 |
+
import torchvision.transforms as T
|
| 45 |
+
from PIL import Image
|
| 46 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
| 47 |
+
|
| 48 |
+
from cnn_model import PhishCNN
|
| 49 |
+
|
| 50 |
+
# ββ Check data βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
phishing_dir = SCREENSHOTS_DIR / "phishing"
|
| 52 |
+
legitimate_dir = SCREENSHOTS_DIR / "legitimate"
|
| 53 |
+
|
| 54 |
+
if not phishing_dir.exists() or not legitimate_dir.exists():
|
| 55 |
+
print(f"\nβ οΈ Screenshot directories not found:")
|
| 56 |
+
print(f" Expected: {phishing_dir}")
|
| 57 |
+
print(f" Expected: {legitimate_dir}")
|
| 58 |
+
print(f"\n Run: python screenshot_collector.py")
|
| 59 |
+
|
| 60 |
+
# Create dirs and generate placeholder images for testing
|
| 61 |
+
phishing_dir.mkdir(parents=True, exist_ok=True)
|
| 62 |
+
legitimate_dir.mkdir(parents=True, exist_ok=True)
|
| 63 |
+
|
| 64 |
+
print(" Generating synthetic training images...")
|
| 65 |
+
_generate_synthetic_screenshots(phishing_dir, legitimate_dir)
|
| 66 |
+
|
| 67 |
+
phishing_files = list(phishing_dir.glob("*.png")) + list(phishing_dir.glob("*.jpg"))
|
| 68 |
+
legit_files = list(legitimate_dir.glob("*.png")) + list(legitimate_dir.glob("*.jpg"))
|
| 69 |
+
|
| 70 |
+
print(f"\nπ Dataset:")
|
| 71 |
+
print(f" Phishing screenshots: {len(phishing_files)}")
|
| 72 |
+
print(f" Legitimate screenshots: {len(legit_files)}")
|
| 73 |
+
|
| 74 |
+
if len(phishing_files) < 10 or len(legit_files) < 10:
|
| 75 |
+
print("β οΈ Too few screenshots. Generating synthetic images...")
|
| 76 |
+
_generate_synthetic_screenshots(phishing_dir, legitimate_dir, count=100)
|
| 77 |
+
phishing_files = list(phishing_dir.glob("*.png"))
|
| 78 |
+
legit_files = list(legitimate_dir.glob("*.png"))
|
| 79 |
+
print(f" Phishing: {len(phishing_files)}, Legitimate: {len(legit_files)}")
|
| 80 |
+
|
| 81 |
+
# ββ Dataset ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
train_transform = T.Compose([
|
| 83 |
+
T.Resize((224, 224)),
|
| 84 |
+
T.RandomHorizontalFlip(),
|
| 85 |
+
T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
| 86 |
+
T.RandomRotation(5),
|
| 87 |
+
T.ToTensor(),
|
| 88 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 89 |
+
])
|
| 90 |
+
|
| 91 |
+
val_transform = T.Compose([
|
| 92 |
+
T.Resize((224, 224)),
|
| 93 |
+
T.ToTensor(),
|
| 94 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
| 95 |
+
])
|
| 96 |
+
|
| 97 |
+
class ScreenshotDataset(Dataset):
|
| 98 |
+
def __init__(self, files: List[Path], label: int, transform):
|
| 99 |
+
self.files = files
|
| 100 |
+
self.label = label
|
| 101 |
+
self.transform = transform
|
| 102 |
+
|
| 103 |
+
def __len__(self) -> int:
|
| 104 |
+
return len(self.files)
|
| 105 |
+
|
| 106 |
+
def __getitem__(self, idx: int):
|
| 107 |
+
try:
|
| 108 |
+
img = Image.open(self.files[idx]).convert("RGB")
|
| 109 |
+
tensor = self.transform(img)
|
| 110 |
+
return tensor, self.label
|
| 111 |
+
except Exception:
|
| 112 |
+
# Return black image on error
|
| 113 |
+
tensor = torch.zeros(3, 224, 224)
|
| 114 |
+
return tensor, self.label
|
| 115 |
+
|
| 116 |
+
# Split: 80% train, 20% val
|
| 117 |
+
import random
|
| 118 |
+
random.shuffle(phishing_files)
|
| 119 |
+
random.shuffle(legit_files)
|
| 120 |
+
|
| 121 |
+
phish_split = int(len(phishing_files) * 0.8)
|
| 122 |
+
legit_split = int(len(legit_files) * 0.8)
|
| 123 |
+
|
| 124 |
+
train_phish = phishing_files[:phish_split]
|
| 125 |
+
val_phish = phishing_files[phish_split:]
|
| 126 |
+
train_legit = legit_files[:legit_split]
|
| 127 |
+
val_legit = legit_files[legit_split:]
|
| 128 |
+
|
| 129 |
+
train_dataset = (
|
| 130 |
+
ScreenshotDataset(train_phish, 1, train_transform)
|
| 131 |
+
+ ScreenshotDataset(train_legit, 0, train_transform)
|
| 132 |
+
)
|
| 133 |
+
val_dataset = (
|
| 134 |
+
ScreenshotDataset(val_phish, 1, val_transform)
|
| 135 |
+
+ ScreenshotDataset(val_legit, 0, val_transform)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
|
| 139 |
+
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
|
| 140 |
+
|
| 141 |
+
# ββ Model ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 143 |
+
print(f"\nπ€ Device: {device}")
|
| 144 |
+
|
| 145 |
+
model = PhishCNN(pretrained=True).to(device)
|
| 146 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 147 |
+
total = sum(p.numel() for p in model.parameters())
|
| 148 |
+
print(f" Parameters: {total:,} total, {trainable:,} trainable")
|
| 149 |
+
|
| 150 |
+
# Only optimize head parameters
|
| 151 |
+
head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
|
| 152 |
+
optimizer = AdamW(head_params, lr=1e-3, weight_decay=1e-4)
|
| 153 |
+
loss_fn = nn.BCELoss()
|
| 154 |
+
|
| 155 |
+
# ββ Training βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 156 |
+
EPOCHS = 2
|
| 157 |
+
best_val_acc = 0.0
|
| 158 |
+
|
| 159 |
+
print(f"\nποΈ Training for {EPOCHS} epochs...")
|
| 160 |
+
print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7}")
|
| 161 |
+
print(f" {'β'*5} | {'β'*8} | {'β'*9} | {'β'*7}")
|
| 162 |
+
|
| 163 |
+
for epoch in range(1, EPOCHS + 1):
|
| 164 |
+
# Train
|
| 165 |
+
model.train()
|
| 166 |
+
total_loss = 0.0
|
| 167 |
+
train_preds, train_labels = [], []
|
| 168 |
+
|
| 169 |
+
for batch_x, batch_y in train_loader:
|
| 170 |
+
batch_x = batch_x.to(device)
|
| 171 |
+
batch_y = batch_y.float().to(device)
|
| 172 |
+
|
| 173 |
+
optimizer.zero_grad()
|
| 174 |
+
output = model(batch_x).squeeze()
|
| 175 |
+
loss = loss_fn(output, batch_y)
|
| 176 |
+
loss.backward()
|
| 177 |
+
optimizer.step()
|
| 178 |
+
|
| 179 |
+
total_loss += loss.item()
|
| 180 |
+
preds = (output >= 0.5).int()
|
| 181 |
+
train_preds.extend(preds.cpu().tolist())
|
| 182 |
+
train_labels.extend(batch_y.int().cpu().tolist())
|
| 183 |
+
|
| 184 |
+
avg_loss = total_loss / max(len(train_loader), 1)
|
| 185 |
+
train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0
|
| 186 |
+
|
| 187 |
+
# Validate
|
| 188 |
+
model.eval()
|
| 189 |
+
val_preds, val_labels = [], []
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
for batch_x, batch_y in val_loader:
|
| 192 |
+
batch_x = batch_x.to(device)
|
| 193 |
+
batch_y = batch_y.float().to(device)
|
| 194 |
+
output = model(batch_x).squeeze()
|
| 195 |
+
preds = (output >= 0.5).int()
|
| 196 |
+
val_preds.extend(preds.cpu().tolist())
|
| 197 |
+
val_labels.extend(batch_y.int().cpu().tolist())
|
| 198 |
+
|
| 199 |
+
val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0
|
| 200 |
+
|
| 201 |
+
if epoch % 3 == 0 or epoch == 1:
|
| 202 |
+
print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f}")
|
| 203 |
+
|
| 204 |
+
if val_acc > best_val_acc:
|
| 205 |
+
best_val_acc = val_acc
|
| 206 |
+
torch.save(model.state_dict(), WEIGHTS_PATH)
|
| 207 |
+
|
| 208 |
+
# ββ Final metrics ββββββββββββββββββββββββββββββββββββββββββββ
|
| 209 |
+
if val_labels:
|
| 210 |
+
precision, recall, f1, _ = precision_recall_fscore_support(
|
| 211 |
+
val_labels, val_preds, average="binary", zero_division=0,
|
| 212 |
+
)
|
| 213 |
+
print(f"\nπ Final Validation:")
|
| 214 |
+
print(f" Accuracy: {best_val_acc:.4f}")
|
| 215 |
+
print(f" Precision: {precision:.4f}")
|
| 216 |
+
print(f" Recall: {recall:.4f}")
|
| 217 |
+
print(f" F1 Score: {f1:.4f}")
|
| 218 |
+
|
| 219 |
+
# ββ Save replay buffer βββββββββββββββββββββββββββββββββββββββ
|
| 220 |
+
all_paths = phishing_files + legit_files
|
| 221 |
+
replay_paths = [str(p) for p in all_paths[:100]]
|
| 222 |
+
replay_labels = [1] * min(len(phishing_files), 50) + [0] * min(len(legit_files), 50)
|
| 223 |
+
REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 224 |
+
torch.save({"paths": replay_paths, "labels": replay_labels}, REPLAY_BUFFER_PATH)
|
| 225 |
+
|
| 226 |
+
print(f"\nβ
CNN weights saved to: {WEIGHTS_PATH}")
|
| 227 |
+
print(f"πΎ Replay buffer saved: {len(replay_paths)} paths β {REPLAY_BUFFER_PATH}")
|
| 228 |
+
print("=" * 60)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _generate_synthetic_screenshots(
|
| 232 |
+
phishing_dir: Path,
|
| 233 |
+
legitimate_dir: Path,
|
| 234 |
+
count: int = 100,
|
| 235 |
+
) -> None:
|
| 236 |
+
"""Generate synthetic screenshots for training when real data unavailable."""
|
| 237 |
+
import random
|
| 238 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 239 |
+
|
| 240 |
+
for label, save_dir, colors in [
|
| 241 |
+
("phishing", phishing_dir, [(200, 50, 50), (180, 30, 30), (220, 80, 60)]),
|
| 242 |
+
("legitimate", legitimate_dir, [(50, 120, 200), (30, 100, 180), (60, 140, 220)]),
|
| 243 |
+
]:
|
| 244 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 245 |
+
existing = len(list(save_dir.glob("*.png")))
|
| 246 |
+
needed = max(0, count - existing)
|
| 247 |
+
|
| 248 |
+
for i in range(needed):
|
| 249 |
+
# Create varied synthetic images
|
| 250 |
+
w, h = 1280, 800
|
| 251 |
+
bg = random.choice(colors)
|
| 252 |
+
img = Image.new("RGB", (w, h), bg)
|
| 253 |
+
draw = ImageDraw.Draw(img)
|
| 254 |
+
|
| 255 |
+
# Add shapes
|
| 256 |
+
for _ in range(random.randint(5, 15)):
|
| 257 |
+
x1 = random.randint(0, w - 100)
|
| 258 |
+
y1 = random.randint(0, h - 100)
|
| 259 |
+
x2 = x1 + random.randint(50, 300)
|
| 260 |
+
y2 = y1 + random.randint(30, 200)
|
| 261 |
+
color = tuple(random.randint(0, 255) for _ in range(3))
|
| 262 |
+
draw.rectangle([x1, y1, x2, y2], fill=color)
|
| 263 |
+
|
| 264 |
+
# Add text-like rectangles
|
| 265 |
+
for _ in range(random.randint(3, 8)):
|
| 266 |
+
x = random.randint(100, w - 400)
|
| 267 |
+
y = random.randint(100, h - 100)
|
| 268 |
+
draw.rectangle([x, y, x + random.randint(100, 300), y + 20],
|
| 269 |
+
fill=(255, 255, 255))
|
| 270 |
+
|
| 271 |
+
img.save(save_dir / f"synthetic_{i:04d}.png")
|
| 272 |
+
|
| 273 |
+
logger.info(f"Generated synthetic screenshots in {phishing_dir.parent}")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
if __name__ == "__main__":
|
| 277 |
+
main()
|
train_gnn.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - gnn/train_gnn.py
|
| 3 |
+
# Full GNN training script.
|
| 4 |
+
#
|
| 5 |
+
# Downloads PhishTank bz2 + TRANCO zip + Kaggle CSV mirror
|
| 6 |
+
# Builds training graphs, 40 epochs, saves gnn_weights.pt
|
| 7 |
+
# 70/15/15 train/val/test split with stratification
|
| 8 |
+
# Saves replay buffer to gnn_replay_buffer.pt
|
| 9 |
+
# ============================================================
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
import random
|
| 15 |
+
import logging
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import List, Tuple
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=logging.INFO,
|
| 21 |
+
format="%(asctime)s | %(levelname)-7s | %(message)s",
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger("phishguard.gnn.train")
|
| 24 |
+
|
| 25 |
+
# Paths
|
| 26 |
+
GNN_DIR = Path(__file__).parent
|
| 27 |
+
BACKEND_DIR = GNN_DIR.parent
|
| 28 |
+
WEIGHTS_PATH = GNN_DIR / "gnn_weights.pt"
|
| 29 |
+
REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "gnn_replay_buffer.pt"
|
| 30 |
+
|
| 31 |
+
# Add backend to path for imports
|
| 32 |
+
sys.path.insert(0, str(BACKEND_DIR))
|
| 33 |
+
sys.path.insert(0, str(GNN_DIR))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main() -> None:
|
| 37 |
+
print("=" * 60)
|
| 38 |
+
print("PhishGuard AI β GNN Training")
|
| 39 |
+
print("=" * 60)
|
| 40 |
+
|
| 41 |
+
import torch
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
| 44 |
+
|
| 45 |
+
from domain_graph_builder import DomainGraphBuilder
|
| 46 |
+
from gnn_model import PhishGNN, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM
|
| 47 |
+
|
| 48 |
+
# ββ Download data ββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
from data_collector import download_phishtank, download_tranco, merge_datasets
|
| 50 |
+
|
| 51 |
+
print("\nπ₯ Downloading datasets...")
|
| 52 |
+
phish_urls = download_phishtank(max_urls=50)
|
| 53 |
+
legit_urls = download_tranco(n=50)
|
| 54 |
+
print(f" Phishing URLs: {len(phish_urls)}")
|
| 55 |
+
print(f" Legitimate URLs: {len(legit_urls)}")
|
| 56 |
+
|
| 57 |
+
train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)
|
| 58 |
+
|
| 59 |
+
# ββ Build graphs βββββββββββββββββββββββββββββββββββββββββββββ
|
| 60 |
+
builder = DomainGraphBuilder()
|
| 61 |
+
CHUNK_SIZE = 4 # Group URLs into small graphs
|
| 62 |
+
|
| 63 |
+
def build_dataset(data: List[Tuple[str, int]], desc: str) -> list:
|
| 64 |
+
"""Build graph dataset from (url, label) pairs."""
|
| 65 |
+
dataset = []
|
| 66 |
+
# Separate by label
|
| 67 |
+
phish = [url for url, label in data if label == 1]
|
| 68 |
+
legit = [url for url, label in data if label == 0]
|
| 69 |
+
|
| 70 |
+
for urls, label in [(phish, 1), (legit, 0)]:
|
| 71 |
+
for i in range(0, len(urls), CHUNK_SIZE):
|
| 72 |
+
chunk = urls[i : i + CHUNK_SIZE]
|
| 73 |
+
if not chunk:
|
| 74 |
+
continue
|
| 75 |
+
graph = builder.build_graph(chunk)
|
| 76 |
+
x = torch.tensor(graph["features"], dtype=torch.float)
|
| 77 |
+
edges = graph["edges"]
|
| 78 |
+
if edges and len(edges) > 0:
|
| 79 |
+
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
|
| 80 |
+
else:
|
| 81 |
+
# Self-loops for graphs with no edges
|
| 82 |
+
n = x.size(0)
|
| 83 |
+
edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
|
| 84 |
+
dataset.append({
|
| 85 |
+
"x": x,
|
| 86 |
+
"edge_index": edge_index,
|
| 87 |
+
"y": torch.tensor([float(label)]),
|
| 88 |
+
})
|
| 89 |
+
|
| 90 |
+
random.shuffle(dataset)
|
| 91 |
+
print(f" {desc}: {len(dataset)} graphs")
|
| 92 |
+
return dataset
|
| 93 |
+
|
| 94 |
+
print("\nπ¨ Building graphs...")
|
| 95 |
+
train_graphs = build_dataset(train_data, "Train")
|
| 96 |
+
val_graphs = build_dataset(val_data, "Val")
|
| 97 |
+
test_graphs = build_dataset(test_data, "Test")
|
| 98 |
+
|
| 99 |
+
# ββ Create model βββββββββββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 101 |
+
print(f"\nπ€ Device: {device}")
|
| 102 |
+
|
| 103 |
+
model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
|
| 104 |
+
model = model.to(device)
|
| 105 |
+
model_type = "GCN" if PYGEOM_AVAILABLE else "MLP"
|
| 106 |
+
print(f" Model: Phish{model_type}")
|
| 107 |
+
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 108 |
+
|
| 109 |
+
# ββ Training βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 110 |
+
EPOCHS = 2
|
| 111 |
+
LR = 0.001
|
| 112 |
+
WEIGHT_DECAY = 1e-4
|
| 113 |
+
|
| 114 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
|
| 115 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 116 |
+
optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6,
|
| 117 |
+
)
|
| 118 |
+
loss_fn = F.binary_cross_entropy
|
| 119 |
+
|
| 120 |
+
best_val_acc = 0.0
|
| 121 |
+
best_epoch = 0
|
| 122 |
+
|
| 123 |
+
print(f"\nποΈ Training for {EPOCHS} epochs...")
|
| 124 |
+
print(f" {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7} | {'LR':>10}")
|
| 125 |
+
print(f" {'β' * 5} | {'β' * 8} | {'β' * 9} | {'β' * 7} | {'β' * 10}")
|
| 126 |
+
|
| 127 |
+
for epoch in range(1, EPOCHS + 1):
|
| 128 |
+
# ββ Train ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
model.train()
|
| 130 |
+
total_loss = 0.0
|
| 131 |
+
train_preds = []
|
| 132 |
+
train_labels = []
|
| 133 |
+
|
| 134 |
+
random.shuffle(train_graphs)
|
| 135 |
+
for item in train_graphs:
|
| 136 |
+
x = item["x"].to(device)
|
| 137 |
+
ei = item["edge_index"].to(device)
|
| 138 |
+
y = item["y"].to(device)
|
| 139 |
+
|
| 140 |
+
optimizer.zero_grad()
|
| 141 |
+
out = model(x, ei)
|
| 142 |
+
loss = loss_fn(out.squeeze(), y.squeeze())
|
| 143 |
+
loss.backward()
|
| 144 |
+
optimizer.step()
|
| 145 |
+
|
| 146 |
+
total_loss += loss.item()
|
| 147 |
+
pred = 1 if out.squeeze().item() >= 0.5 else 0
|
| 148 |
+
train_preds.append(pred)
|
| 149 |
+
train_labels.append(int(y.item()))
|
| 150 |
+
|
| 151 |
+
avg_loss = total_loss / max(len(train_graphs), 1)
|
| 152 |
+
train_acc = accuracy_score(train_labels, train_preds)
|
| 153 |
+
|
| 154 |
+
# ββ Validate βββββββββββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
model.eval()
|
| 156 |
+
val_preds = []
|
| 157 |
+
val_labels = []
|
| 158 |
+
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
for item in val_graphs:
|
| 161 |
+
x = item["x"].to(device)
|
| 162 |
+
ei = item["edge_index"].to(device)
|
| 163 |
+
y = item["y"].to(device)
|
| 164 |
+
|
| 165 |
+
out = model(x, ei)
|
| 166 |
+
pred = 1 if out.squeeze().item() >= 0.5 else 0
|
| 167 |
+
val_preds.append(pred)
|
| 168 |
+
val_labels.append(int(y.item()))
|
| 169 |
+
|
| 170 |
+
val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0
|
| 171 |
+
scheduler.step(avg_loss)
|
| 172 |
+
current_lr = optimizer.param_groups[0]["lr"]
|
| 173 |
+
|
| 174 |
+
# Print progress
|
| 175 |
+
if epoch % 5 == 0 or epoch == 1:
|
| 176 |
+
print(f" {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f} | {current_lr:>10.6f}")
|
| 177 |
+
|
| 178 |
+
# Save best model
|
| 179 |
+
if val_acc > best_val_acc:
|
| 180 |
+
best_val_acc = val_acc
|
| 181 |
+
best_epoch = epoch
|
| 182 |
+
torch.save(model.state_dict(), WEIGHTS_PATH)
|
| 183 |
+
|
| 184 |
+
print(f"\n Best val accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
|
| 185 |
+
|
| 186 |
+
# ββ Test βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 187 |
+
# Reload best weights
|
| 188 |
+
model.load_state_dict(
|
| 189 |
+
torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
|
| 190 |
+
)
|
| 191 |
+
model.eval()
|
| 192 |
+
|
| 193 |
+
test_preds = []
|
| 194 |
+
test_labels = []
|
| 195 |
+
with torch.no_grad():
|
| 196 |
+
for item in test_graphs:
|
| 197 |
+
x = item["x"].to(device)
|
| 198 |
+
ei = item["edge_index"].to(device)
|
| 199 |
+
y = item["y"].to(device)
|
| 200 |
+
|
| 201 |
+
out = model(x, ei)
|
| 202 |
+
pred = 1 if out.squeeze().item() >= 0.5 else 0
|
| 203 |
+
test_preds.append(pred)
|
| 204 |
+
test_labels.append(int(y.item()))
|
| 205 |
+
|
| 206 |
+
test_acc = accuracy_score(test_labels, test_preds) if test_labels else 0.0
|
| 207 |
+
precision, recall, f1, _ = precision_recall_fscore_support(
|
| 208 |
+
test_labels, test_preds, average="binary", zero_division=0,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
print(f"\nπ Test Results:")
|
| 212 |
+
print(f" Accuracy: {test_acc:.4f}")
|
| 213 |
+
print(f" Precision: {precision:.4f}")
|
| 214 |
+
print(f" Recall: {recall:.4f}")
|
| 215 |
+
print(f" F1 Score: {f1:.4f}")
|
| 216 |
+
|
| 217 |
+
# ββ Save replay buffer βββββββββββββββββββββββββββββββββββββββ
|
| 218 |
+
REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 219 |
+
replay_buffer = train_graphs[:500] # Keep last 500 samples
|
| 220 |
+
torch.save(replay_buffer, REPLAY_BUFFER_PATH)
|
| 221 |
+
print(f"\nπΎ Replay buffer saved: {len(replay_buffer)} samples β {REPLAY_BUFFER_PATH}")
|
| 222 |
+
|
| 223 |
+
print(f"\nβ
GNN weights saved to: {WEIGHTS_PATH}")
|
| 224 |
+
print("=" * 60)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
|
| 228 |
+
main()
|
url_heuristics.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - url_heuristics.py
|
| 3 |
+
# Tier 2: Heuristic Rule Engine β 15 independent signals
|
| 4 |
+
#
|
| 5 |
+
# Pure Python regex ONLY β ZERO I/O, ZERO ML, ZERO network
|
| 6 |
+
# All regex patterns precompiled in __init__ for < 2ms latency
|
| 7 |
+
# Max raw score: 135 β normalized to 0-100
|
| 8 |
+
# Decision: score >= 80 β BLOCK | < 80 β pass to Tier 3
|
| 9 |
+
# ============================================================
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
import math
|
| 15 |
+
import time
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from typing import List, Tuple
|
| 18 |
+
from urllib.parse import urlparse, parse_qs
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class HeuristicResult:
    """Result from the Tier 2 heuristic scoring engine."""
    score: int  # 0-100 normalized score
    signals: List[str] = field(default_factory=list)  # human-readable triggered rules
    raw_score: int = 0  # pre-normalization total (nominally max 135)


# Normalization ceiling for the raw score.
# NOTE(review): the 15 individual weights actually sum to 130
# (25+20+15+15+10 + 8*5 + 3 + 2); the constant is kept at 135 so existing
# normalized scores stay stable — confirm before "fixing" the arithmetic.
MAX_RAW_SCORE: int = 135


class HeuristicScorer:
    """
    Tier 2 Heuristic Rule Engine.

    Scores URLs 0-100 across 15 independent regex/math signals.
    All regex patterns are precompiled in __init__ (called once at startup).
    The score() method runs all 15 checks in < 2ms on standard hardware.
    """

    def __init__(self) -> None:
        # ── Precompile ALL regex patterns (called once) ──────────────

        # Signal 1: dotted-quad hostname. Octet ranges are deliberately NOT
        # validated (e.g. 999.1.1.1 still matches) — any all-numeric dotted
        # quad used as a hostname is suspicious enough for this tier.
        self._re_ip_hostname = re.compile(
            r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$"
        )

        # Signal 2: cheap / abuse-prone TLDs.
        self._suspicious_tlds: frozenset[str] = frozenset({
            ".xyz", ".tk", ".ml", ".ga", ".cf",
            ".gq", ".pw", ".top", ".click",
        })

        # Signal 3: credential-phishing vocabulary.
        self._phishing_keywords: Tuple[str, ...] = (
            "login", "verify", "secure", "update", "account",
            "banking", "signin", "reset", "confirm", "suspend",
            "webscr", "cmd", "payment", "alert",
        )
        self._re_phishing_keywords = re.compile(
            r"(?:" + "|".join(re.escape(kw) for kw in self._phishing_keywords) + r")",
            re.IGNORECASE,
        )

        # Signal 4: commonly-spoofed brands and their legitimate domains.
        self._brand_names: Tuple[str, ...] = (
            "paypal", "google", "apple", "microsoft", "amazon",
            "netflix", "facebook", "instagram", "chase",
            "wellsfargo", "bankofamerica",
        )
        self._brand_legitimate_domains: dict[str, frozenset[str]] = {
            brand: frozenset({
                f"{brand}.com", f"www.{brand}.com",
                f"{brand}.org", f"{brand}.net",
            })
            for brand in self._brand_names
        }
        self._re_brands = re.compile(
            r"(?:" + "|".join(re.escape(b) for b in self._brand_names) + r")",
            re.IGNORECASE,
        )

        # Signal 11: explicit port at the END of the netloc.
        # BUGFIX: anchored with `$`. The previous unanchored r":(\d+)" matched
        # the trailing ":1" of IPv6 literal hosts like "[2001:db8::1]" and
        # digit runs in userinfo ("user:123@host"), producing false
        # "Non-standard port" signals.
        self._re_non_standard_port = re.compile(r":(\d+)$")
        self._standard_ports: frozenset[int] = frozenset({80, 443, 8080})

        # Signal 12: "//" anywhere after the first character of the path.
        self._re_double_slash = re.compile(r"(?<=.)//")

        # Signal 13: one percent-encoded byte.
        self._re_url_encoded = re.compile(r"%[0-9A-Fa-f]{2}")

    # ── Public API ───────────────────────────────────────────────────

    def score(self, url: str) -> HeuristicResult:
        """
        Score a raw URL string from 0-100 for phishing probability.

        Runs all 15 checks over a single urlparse() of the input and
        returns a HeuristicResult with the normalized score, the raw
        (pre-normalization) score, and the triggered-signal descriptions.
        """
        raw_score: int = 0
        signals: List[str] = []

        # Parse URL once, reuse across all checks. Scheme-less inputs get a
        # dummy "http://" prefix so urlparse populates the netloc.
        try:
            parsed = urlparse(url if "://" in url else f"http://{url}")
        except Exception:
            return HeuristicResult(score=0, signals=["parse_error"], raw_score=0)

        hostname: str = (parsed.hostname or "").lower()
        path: str = parsed.path or ""
        query: str = parsed.query or ""
        url_lower: str = url.lower()

        # Extract registrable domain (last two labels) for the brand check.
        host_parts = hostname.split(".")
        domain = ".".join(host_parts[-2:]) if len(host_parts) >= 2 else hostname

        # Run all 15 checks; each returns (points, human-readable signal).
        checks: List[Tuple[int, str]] = [
            self._check_ip_hostname(hostname),
            self._check_suspicious_tld(hostname),
            self._check_phishing_keywords(url_lower),
            self._check_brand_spoofing(url_lower, domain),
            self._check_subdomain_depth(hostname),
            self._check_url_length(url),
            self._check_domain_length(hostname),
            self._check_hyphen_count(hostname),
            self._check_digit_ratio(hostname),
            self._check_shannon_entropy(hostname),
            self._check_non_standard_port(parsed.netloc),
            self._check_double_slash_redirect(path),
            self._check_url_encoding(url),
            self._check_query_length(query),
            self._check_path_depth(path),
        ]

        for points, signal in checks:
            if points > 0:
                raw_score += points
                signals.append(signal)

        # Normalize: raw_score / MAX_RAW_SCORE * 100, capped at 100.
        normalized = min(round(raw_score / MAX_RAW_SCORE * 100), 100)

        return HeuristicResult(
            score=normalized,
            signals=signals,
            raw_score=raw_score,
        )

    # ── 15 Individual Signal Checks ──────────────────────────────────

    def _check_ip_hostname(self, hostname: str) -> Tuple[int, str]:
        """Signal 1: IP address as hostname (25 points)."""
        if self._re_ip_hostname.match(hostname):
            return 25, "IP address as hostname"
        return 0, ""

    def _check_suspicious_tld(self, hostname: str) -> Tuple[int, str]:
        """Signal 2: Suspicious/cheap TLD (20 points)."""
        # No TLD in the set is a suffix of another, so at most one can match.
        for tld in self._suspicious_tlds:
            if hostname.endswith(tld):
                return 20, f"Suspicious TLD ({tld})"
        return 0, ""

    def _check_phishing_keywords(self, url_lower: str) -> Tuple[int, str]:
        """Signal 3: Phishing keywords in URL (15 points)."""
        matches = self._re_phishing_keywords.findall(url_lower)
        if matches:
            unique = set(m.lower() for m in matches)
            return 15, f"Phishing keywords: {', '.join(sorted(unique))}"
        return 0, ""

    def _check_brand_spoofing(self, url_lower: str, domain: str) -> Tuple[int, str]:
        """Signal 4: Brand name in URL but wrong domain (15 points)."""
        matches = self._re_brands.findall(url_lower)
        for brand_match in matches:
            brand = brand_match.lower()
            legit_domains = self._brand_legitimate_domains.get(brand, frozenset())
            # Flag only when the registrable domain is NOT the brand's own.
            if domain not in legit_domains and f"www.{domain}" not in legit_domains:
                return 15, f"Brand spoofing: '{brand}' on non-brand domain"
        return 0, ""

    def _check_subdomain_depth(self, hostname: str) -> Tuple[int, str]:
        """Signal 5: Excessive subdomains >= 3 (10 points)."""
        parts = hostname.split(".")
        subdomain_count = max(0, len(parts) - 2)
        if subdomain_count >= 3:
            return 10, f"Excessive subdomains ({subdomain_count})"
        return 0, ""

    def _check_url_length(self, url: str) -> Tuple[int, str]:
        """Signal 6: URL length > 100 chars (5 points)."""
        if len(url) > 100:
            return 5, f"Long URL ({len(url)} chars)"
        return 0, ""

    def _check_domain_length(self, hostname: str) -> Tuple[int, str]:
        """Signal 7: Domain length > 30 chars (5 points)."""
        if len(hostname) > 30:
            return 5, f"Long domain ({len(hostname)} chars)"
        return 0, ""

    def _check_hyphen_count(self, hostname: str) -> Tuple[int, str]:
        """Signal 8: Hyphen count >= 3 in domain (5 points)."""
        count = hostname.count("-")
        if count >= 3:
            return 5, f"Excessive hyphens in domain ({count})"
        return 0, ""

    def _check_digit_ratio(self, hostname: str) -> Tuple[int, str]:
        """Signal 9: Digit ratio in domain > 0.3 (5 points)."""
        if not hostname:
            return 0, ""
        digits = sum(1 for c in hostname if c.isdigit())
        ratio = digits / len(hostname)
        if ratio > 0.3:
            return 5, f"High digit ratio in domain ({ratio:.2f})"
        return 0, ""

    def _check_shannon_entropy(self, hostname: str) -> Tuple[int, str]:
        """Signal 10: High Shannon entropy > 3.5 (5 points).

        High character entropy is typical of algorithmically generated
        (DGA-style) domains.
        """
        if not hostname:
            return 0, ""
        length = len(hostname)
        freq: dict[str, int] = {}
        for c in hostname:
            freq[c] = freq.get(c, 0) + 1
        entropy = -sum(
            (count / length) * math.log2(count / length)
            for count in freq.values()
            if count > 0
        )
        if entropy > 3.5:
            return 5, f"High entropy domain ({entropy:.2f})"
        return 0, ""

    def _check_non_standard_port(self, netloc: str) -> Tuple[int, str]:
        """Signal 11: Non-standard port in URL (5 points)."""
        match = self._re_non_standard_port.search(netloc)
        if match:
            port = int(match.group(1))
            if port not in self._standard_ports:
                return 5, f"Non-standard port (:{port})"
        return 0, ""

    def _check_double_slash_redirect(self, path: str) -> Tuple[int, str]:
        """Signal 12: Double slash redirect in path (5 points)."""
        if self._re_double_slash.search(path):
            return 5, "Double-slash redirect in path"
        return 0, ""

    def _check_url_encoding(self, url: str) -> Tuple[int, str]:
        """Signal 13: URL-encoded characters > 5 (5 points)."""
        encoded_chars = self._re_url_encoded.findall(url)
        if len(encoded_chars) > 5:
            return 5, f"Excessive URL encoding ({len(encoded_chars)} encoded chars)"
        return 0, ""

    def _check_query_length(self, query: str) -> Tuple[int, str]:
        """Signal 14: Query string length > 200 (3 points)."""
        if len(query) > 200:
            return 3, f"Long query string ({len(query)} chars)"
        return 0, ""

    def _check_path_depth(self, path: str) -> Tuple[int, str]:
        """Signal 15: Path depth > 6 levels (2 points)."""
        segments = [s for s in path.split("/") if s]
        if len(segments) > 6:
            return 2, f"Deep path ({len(segments)} levels)"
        return 0, ""
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ββ Legacy compatibility wrapper βββββββββββββββββββββββββββββββββββββ
|
| 271 |
+
# Single shared scorer instance for the legacy entry point below.
_default_scorer = HeuristicScorer()


def analyze_url(url: str) -> dict:
    """
    Legacy-compatible wrapper around HeuristicScorer.
    Returns dict with 'score', 'flags', 'is_suspicious' for backward compat.
    """
    outcome = _default_scorer.score(url)
    report = {
        "score": outcome.score,
        "flags": outcome.signals,
    }
    # Historical threshold: a normalized score of 40+ is "suspicious".
    report["is_suspicious"] = outcome.score >= 40
    report["raw_score"] = outcome.raw_score
    return report
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ββ Benchmark (run directly to test latency) βββββββββββββββββββββββββ
|
| 289 |
+
if __name__ == "__main__":
    # Smoke-test corpus: a mix of benign URLs and obvious phishing shapes
    # (IP host, brand-in-subdomain, cheap TLD, long encoded query, etc.).
    test_urls = [
        "https://www.google.com",
        "http://192.168.1.1/admin/login",
        "https://paypal-secure-login.xyz/verify/account?id=12345",
        "https://a.b.c.d.evil.com/login/secure/update/verify",
        "https://secure-login-bank-now.tk:9999/path/a/b/c/d/e/f/g.php",
        "http://xn--pypal-4ve.com/signin?token=%2F%40%3A%20%2F%40%3A%20extra&" + "a" * 200,
        "https://www.amazon.com/dp/B0something",
        "https://microsft-l0gin-verfy.ga/account/suspended//evil.com/reset",
    ]

    scorer = HeuristicScorer()
    print("=" * 70)
    print("PhishGuard Tier 2 — Heuristic Benchmark")
    print("=" * 70)

    # Per-URL wall-clock latency in microseconds.
    times: list[float] = []
    for url in test_urls:
        start = time.perf_counter()
        result = scorer.score(url)
        elapsed_us = (time.perf_counter() - start) * 1_000_000

        times.append(elapsed_us)
        # score >= 80 → Tier 2 blocks outright; anything lower escalates.
        action = "BLOCK" if result.score >= 80 else "→ Tier 3"
        print(f"\n URL: {url[:80]}...")
        print(f" Score: {result.score}/100 (raw {result.raw_score}/{MAX_RAW_SCORE}) → {action}")
        if result.signals:
            for sig in result.signals:
                print(f" ⚡ {sig}")
        print(f" Latency: {elapsed_us:.0f}µs")

    print("\n" + "=" * 70)
    print(f" Avg latency: {sum(times)/len(times):.0f}µs")
    print(f" Max latency: {max(times):.0f}µs")
    print(f" Target: < 2000µs (2ms)")
    print(f" Result: {'✅ PASS' if max(times) < 2000 else '❌ FAIL'}")
    print("=" * 70)
|
visual_analyzer.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# PhishGuard AI - visual_analyzer.py
|
| 3 |
+
# Takes a screenshot of a webpage using a headless browser
|
| 4 |
+
# and analyzes it for visual phishing indicators.
|
| 5 |
+
#
|
| 6 |
+
# Screenshot parameters (from architecture doc 2.3):
|
| 7 |
+
# Viewport: 1280Γ800 (standard desktop resolution)
|
| 8 |
+
# Timeout: 10s (prevent hanging on slow/malicious pages)
|
| 9 |
+
# Wait: domcontentloaded (faster than networkidle)
|
| 10 |
+
# Blocked: fonts, media, video (60-70% faster load)
|
| 11 |
+
# User-Agent: Chrome 120 string (avoid bot detection)
|
| 12 |
+
#
|
| 13 |
+
# Tier 4 is OPTIONAL β controlled by env var ENABLE_VISUAL_TIER.
|
| 14 |
+
# Set ENABLE_VISUAL_TIER=1 to enable.
|
| 15 |
+
# Unset / set 0 β tier 4 is skipped with "tier4_disabled".
|
| 16 |
+
#
|
| 17 |
+
# Render.com: If deploying with Playwright, your render.yaml
|
| 18 |
+
# build command must install Chromium deps. See render.yaml
|
| 19 |
+
# comments and the Dockerfile for required apt packages.
|
| 20 |
+
#
|
| 21 |
+
# Latency budget: < 200ms for screenshot capture
|
| 22 |
+
# ============================================================
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import re
|
| 28 |
+
import time
|
| 29 |
+
import hashlib
|
| 30 |
+
import logging
|
| 31 |
+
from urllib.parse import urlparse
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger("phishguard.visual")
|
| 34 |
+
|
| 35 |
+
# ββ Environment gate βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
ENABLE_VISUAL_TIER = os.environ.get("ENABLE_VISUAL_TIER", "0").strip() in ("1", "true", "yes")
|
| 37 |
+
|
| 38 |
+
if not ENABLE_VISUAL_TIER:
|
| 39 |
+
print("[PhishGuard] Tier 4 visual analysis DISABLED (set ENABLE_VISUAL_TIER=1 to enable)")
|
| 40 |
+
|
| 41 |
+
# ββ Playwright availability ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
PLAYWRIGHT_AVAILABLE = False
|
| 43 |
+
if ENABLE_VISUAL_TIER:
|
| 44 |
+
try:
|
| 45 |
+
from playwright.async_api import async_playwright
|
| 46 |
+
PLAYWRIGHT_AVAILABLE = True
|
| 47 |
+
print("[PhishGuard] Playwright available β screenshot capture enabled")
|
| 48 |
+
except ImportError:
|
| 49 |
+
print("[PhishGuard] Playwright not installed β visual analysis will use heuristic-only mode")
|
| 50 |
+
|
| 51 |
+
# ββ PIL availability βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
_pil_available = False
|
| 53 |
+
try:
|
| 54 |
+
from PIL import Image
|
| 55 |
+
import io as _io
|
| 56 |
+
_pil_available = True
|
| 57 |
+
except ImportError:
|
| 58 |
+
print("[PhishGuard] Pillow not available β color analysis disabled")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ββ Screenshot cache config ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "screenshots")
|
| 63 |
+
_CACHE_TTL = 24 * 60 * 60 # 24 hours in seconds
|
| 64 |
+
|
| 65 |
+
os.makedirs(_CACHE_DIR, exist_ok=True)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ββ Brand / financial keyword databases ββββββββββββββββββββββββββββββββββββββ
|
| 69 |
+
BRAND_DATABASE = {
|
| 70 |
+
# brand_keyword β list of legitimate domains
|
| 71 |
+
"paypal": ["paypal.com"],
|
| 72 |
+
"apple": ["apple.com", "icloud.com"],
|
| 73 |
+
"google": ["google.com", "gmail.com", "accounts.google.com"],
|
| 74 |
+
"amazon": ["amazon.com", "amazon.co.uk", "aws.amazon.com"],
|
| 75 |
+
"microsoft": ["microsoft.com", "live.com", "outlook.com", "office.com"],
|
| 76 |
+
"netflix": ["netflix.com"],
|
| 77 |
+
"facebook": ["facebook.com", "fb.com"],
|
| 78 |
+
"instagram": ["instagram.com"],
|
| 79 |
+
"chase": ["chase.com"],
|
| 80 |
+
"wellsfargo": ["wellsfargo.com"],
|
| 81 |
+
"bankofamerica": ["bankofamerica.com"],
|
| 82 |
+
"citibank": ["citibank.com", "citi.com"],
|
| 83 |
+
"hsbc": ["hsbc.com"],
|
| 84 |
+
"hdfc": ["hdfcbank.com"],
|
| 85 |
+
"icici": ["icicibank.com"],
|
| 86 |
+
"sbi": ["onlinesbi.com", "sbi.co.in"],
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
FINANCIAL_BRANDS = {
|
| 90 |
+
"paypal", "chase", "wellsfargo", "bankofamerica", "citibank",
|
| 91 |
+
"hsbc", "hdfc", "icici", "sbi", "bank", "banking",
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _domain_hash(url: str) -> str:
    """Return a stable 16-hex-char cache key derived from the URL's hostname."""
    target = url
    try:
        normalized = url if url.startswith("http") else "http://" + url
        # Fall back to the raw URL when no hostname can be parsed out.
        target = urlparse(normalized).hostname or url
    except Exception:
        pass  # unparseable input: hash the raw URL string instead
    return hashlib.sha256(target.encode()).hexdigest()[:16]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _get_root_domain(url: str) -> str:
    """Reduce *url* to its last-two-label domain, e.g. login.paypal.com → paypal.com."""
    try:
        full = url if url.startswith("http") else "http://" + url
        # Note: replace() strips "www." wherever it appears, matching the
        # original behavior.
        host = (urlparse(full).hostname or "").lower().replace("www.", "")
    except Exception:
        return ""
    labels = host.split(".")
    if len(labels) < 2:
        return host
    return ".".join(labels[-2:])
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ──────────────────────────────────────────────────────────────────────────────
|
| 117 |
+
# SCREENSHOT CAPTURE (with cache)
|
| 118 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
|
| 120 |
+
def _get_cached_screenshot(url: str) -> bytes | None:
    """
    Return cached screenshot bytes for this URL's domain, or None.

    A cache file is honoured only while younger than _CACHE_TTL (24h);
    stale entries are deleted on sight and treated as a miss.
    """
    cache_path = os.path.join(_CACHE_DIR, f"{_domain_hash(url)}.png")

    if not os.path.exists(cache_path):
        return None

    age = time.time() - os.path.getmtime(cache_path)
    if age >= _CACHE_TTL:
        # Stale — evict (best-effort) and report a miss.
        try:
            os.remove(cache_path)
        except OSError:
            pass
        return None

    try:
        with open(cache_path, "rb") as f:
            data = f.read()
        logger.info(f"Screenshot cache HIT | url={url} | age={age:.0f}s")
        return data
    except Exception:
        return None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _save_screenshot_cache(url: str, data: bytes):
    """Persist screenshot bytes as screenshots/<domain_hash>.png (best-effort)."""
    try:
        cache_path = os.path.join(_CACHE_DIR, f"{_domain_hash(url)}.png")
        with open(cache_path, "wb") as f:
            f.write(data)
        logger.info(f"Screenshot cached | url={url} | path={cache_path}")
    except Exception as e:
        # Cache writes are an optimization only — never fail the caller.
        logger.warning(f"Screenshot cache write failed | error={e}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
async def take_screenshot(url: str) -> bytes | None:
    """
    Open the URL in a hidden (headless) browser and take a screenshot.
    The user never sees this browser window.

    Uses a 24-hour cache: if screenshots/<domain_hash>.png exists and is
    fresh, returns cached bytes without launching a browser.

    Args:
        url: URL to capture (viewport 1280x800, 10s timeout,
             wait_until="domcontentloaded").

    Returns:
        Screenshot as PNG bytes, or None if tier 4 is disabled,
        Playwright is unavailable, or capture fails.
    """
    # Gate: tier 4 disabled
    if not ENABLE_VISUAL_TIER:
        return None

    # Check cache first — avoids the expensive browser launch entirely.
    cached = _get_cached_screenshot(url)
    if cached is not None:
        return cached

    # Playwright not available → can't take a fresh screenshot
    if not PLAYWRIGHT_AVAILABLE:
        logger.warning(f"Screenshot skipped (no Playwright) | url={url}")
        return None

    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            try:
                context = await browser.new_context(
                    viewport={"width": 1280, "height": 800},
                    ignore_https_errors=True,
                    user_agent=(
                        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                        "AppleWebKit/537.36 (KHTML, like Gecko) "
                        "Chrome/120.0.0.0 Safari/537.36"
                    )
                )
                page = await context.new_page()

                # Block fonts and media to speed up loading (60-70% faster)
                await page.route(
                    "**/*.{woff,woff2,ttf,mp4,mp3,wav}",
                    lambda route: route.abort()
                )

                await page.goto(url, timeout=10000, wait_until="domcontentloaded")

                # BUGFIX: removed dead `page_title` / `has_password_field`
                # extraction — the values were never used in this function
                # (metadata belongs to take_screenshot_with_metadata), and the
                # password-locator query added needless per-call latency.
                screenshot = await page.screenshot(full_page=False)
            finally:
                # BUGFIX: close the browser even if goto/screenshot raises,
                # instead of leaking it until the playwright context exits.
                await browser.close()

        # Cache the screenshot for 24 hours
        if screenshot:
            _save_screenshot_cache(url, screenshot)

        return screenshot

    except Exception as e:
        logger.error(f"Screenshot failed | url={url} | error={e}")
        return None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
async def take_screenshot_with_metadata(url: str) -> dict:
    """
    Enhanced screenshot capture that also extracts page metadata
    (title, login forms) for heuristic visual scoring.

    Returns: {
        "screenshot": bytes|None,
        "page_title": str,
        "has_password_field": bool,
        "uses_https": bool,
        "error": str|None
    }
    """
    # Pre-populate every key so all return paths yield a uniform dict.
    result = {
        "screenshot": None,
        "page_title": "",
        "has_password_field": False,
        "uses_https": url.lower().startswith("https"),
        "error": None,
    }

    # Gate: tier 4 disabled
    if not ENABLE_VISUAL_TIER:
        result["error"] = "tier4_disabled"
        return result

    # Check screenshot cache (metadata won't be cached, just the image)
    cached = _get_cached_screenshot(url)
    if cached is not None:
        result["screenshot"] = cached
        # We can't get page metadata from cache, but we have the image;
        # title/password fields keep their defaults on this path.
        return result

    if not PLAYWRIGHT_AVAILABLE:
        result["error"] = "playwright_not_available"
        return result

    try:
        async with async_playwright() as p:
            browser = await p.chromium.launch(headless=True)
            context = await browser.new_context(
                viewport={"width": 1280, "height": 800},
                ignore_https_errors=True,
                # Real-browser UA string to avoid trivial bot detection.
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                    "AppleWebKit/537.36 (KHTML, like Gecko) "
                    "Chrome/120.0.0.0 Safari/537.36"
                )
            )
            page = await context.new_page()

            # Block fonts and media so slow/malicious pages load faster.
            await page.route(
                "**/*.{woff,woff2,ttf,mp4,mp3,wav}",
                lambda route: route.abort()
            )

            await page.goto(url, timeout=10000, wait_until="domcontentloaded")

            # Extract metadata used by the visual heuristics (brand-in-title,
            # login-form detection).
            result["page_title"] = await page.title() or ""
            result["has_password_field"] = await page.locator("input[type='password']").count() > 0

            screenshot = await page.screenshot(full_page=False)
            await browser.close()

            result["screenshot"] = screenshot

            # Cache the screenshot
            if screenshot:
                _save_screenshot_cache(url, screenshot)

    except Exception as e:
        # Navigation/timeout/browser errors are reported to the caller via
        # result["error"] rather than raised.
        result["error"] = str(e)
        logger.error(f"Screenshot+metadata failed | url={url} | error={e}")

    return result
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 305 |
+
# VISUAL PHISHING HEURISTICS (no CNN needed)
|
| 306 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 307 |
+
|
| 308 |
+
def analyze_visual_heuristic(url: str, page_title: str = "",
                             has_password_field: bool = False) -> dict:
    """
    Score visual-phishing likelihood with pure heuristics (no trained CNN).

    Four independent signals contribute to a 0.0–1.0 score:
      1. A known brand name appears in the page title, but the root domain
         is not one of that brand's legitimate domains (+0.30).
      2. The page exposes a password input (+0.15), with an extra +0.15
         when combined with a brand mismatch from signal 1.
      3. Financial-brand content served without HTTPS (+0.25).
      4. A brand keyword appears in the URL *path* while the domain does
         not belong to that brand — classic path spoofing (+0.15).

    Args:
        url: Full URL under analysis.
        page_title: Title extracted from the rendered page (may be empty).
        has_password_field: Whether an input[type=password] was detected.

    Returns:
        dict with keys: heuristic_visual_score (float 0..1), flags
        (list[str]), brand_mismatch (bool), has_login_form (bool),
        ssl_missing_financial (bool).
    """
    risk = 0.0
    raised_flags = []
    title_brand_conflict = False
    insecure_financial = False

    domain = _get_root_domain(url)
    lowered_url = url.lower()
    lowered_title = (page_title or "").lower()
    is_https = lowered_url.startswith("https")

    def _domain_matches(official_domains):
        # True when any legitimate domain string appears in the root domain.
        return any(legit in domain for legit in official_domains)

    # ── Signal 1: brand in page title, domain doesn't belong to brand ─────
    for brand_name, official_domains in BRAND_DATABASE.items():
        if brand_name in lowered_title and not _domain_matches(official_domains):
            risk += 0.30
            raised_flags.append(f"title_brand_mismatch:{brand_name}")
            title_brand_conflict = True
            break  # a single mismatch is sufficient evidence

    # ── Signal 2: login form present ──────────────────────────────────────
    if has_password_field:
        risk += 0.15
        raised_flags.append("has_password_field")
        # Credential harvesting risk compounds with an impersonated brand.
        if title_brand_conflict:
            risk += 0.15
            raised_flags.append("login_form_with_brand_mismatch")

    # ── Signal 3: financial content without SSL ───────────────────────────
    financial_mentioned = any(
        keyword in lowered_title or keyword in lowered_url
        for keyword in FINANCIAL_BRANDS
    )
    if financial_mentioned and not is_https:
        risk += 0.25
        raised_flags.append("no_ssl_financial_content")
        insecure_financial = True

    # ── Signal 4: brand keyword spoofed into the URL path ─────────────────
    try:
        url_path = (urlparse(url).path or "").lower()
        for brand_name, official_domains in BRAND_DATABASE.items():
            if brand_name in url_path and not _domain_matches(official_domains):
                risk += 0.15
                raised_flags.append(f"brand_in_path_not_domain:{brand_name}")
                break
    except Exception:
        pass  # malformed URL — skip this signal, keep whatever we scored

    return {
        "heuristic_visual_score": round(min(risk, 1.0), 4),
        "flags": raised_flags,
        "brand_mismatch": title_brand_conflict,
        "has_login_form": has_password_field,
        "ssl_missing_financial": insecure_financial,
    }
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def analyze_visual_basic(screenshot_bytes: bytes, url: str) -> dict:
    """
    Basic visual analysis using color histograms.

    Detects if a page uses dominant colors associated with known brands
    (PayPal blue, Microsoft orange, Apple white) while the URL doesn't
    mention that brand — a weak but cheap impersonation signal.

    Note: For full CNN analysis, see cnn/cnn_model.py

    Args:
        screenshot_bytes: Raw PNG/JPEG bytes of the page screenshot
            (may be empty/None when capture failed).
        url: The URL the screenshot was taken from.

    Returns:
        dict with "visual_risk" (float 0..1) and "note"; on success also
        "dominant_rgb" ([r, g, b] averages of the downscaled image).
        Never raises — all failures degrade to a low-risk fallback.
    """
    if not screenshot_bytes:
        return {"visual_risk": 0.1, "note": "screenshot_failed"}

    if not _pil_available:
        return {"visual_risk": 0.1, "note": "pil_not_available"}

    try:
        img = Image.open(_io.BytesIO(screenshot_bytes)).convert("RGB")
        img_small = img.resize((224, 224))

        # Split into channels ONCE. (The previous version called
        # img_small.split() three times; each call allocates all three
        # channel images, so 9 allocations instead of 3.)
        r_band, g_band, b_band = img_small.split()
        r_vals = list(r_band.getdata())
        g_vals = list(g_band.getdata())
        b_vals = list(b_band.getdata())

        r_avg = sum(r_vals) / len(r_vals)
        g_avg = sum(g_vals) / len(g_vals)
        b_avg = sum(b_vals) / len(b_vals)

        risk = 0.2  # baseline
        url_lower = url.lower()

        # PayPal brand colors: deep blue
        if b_avg > r_avg * 1.4 and b_avg > g_avg * 1.3:
            if "paypal" not in url_lower:
                risk += 0.25

        # Microsoft brand colors: orange/blue
        if r_avg > 180 and b_avg < 100:
            if "microsoft" not in url_lower and "office" not in url_lower:
                risk += 0.20

        # Apple brand: mostly white/grey
        if r_avg > 220 and g_avg > 220 and b_avg > 220:
            if "apple" not in url_lower:
                risk += 0.10

        return {
            "visual_risk": round(min(risk, 1.0), 4),
            "dominant_rgb": [round(r_avg), round(g_avg), round(b_avg)],
            "note": "basic_color_analysis"
        }

    except Exception:
        # Deliberate best-effort: corrupt image data must not break the
        # overall analysis pipeline.
        return {"visual_risk": 0.1, "note": "analysis_error"}
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 441 |
+
# FULL TIER 4 ANALYSIS (combines CNN + heuristics + color)
|
| 442 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 443 |
+
|
| 444 |
+
async def run_tier4_analysis(url: str, page_title: str = "",
                             page_snippet: str = "") -> dict:
    """
    Complete Tier 4 visual analysis pipeline.

    Called by main.py for borderline cases (0.40 ≤ P₃ < 0.85).

    Graceful fallback chain:
      1. ENABLE_VISUAL_TIER off        -> status "tier4_disabled"
      2. Screenshot capture fails      -> status "screenshot_failed",
                                          score from URL/title heuristics only
      3. Full path                     -> heuristic + color score blend

    Args:
        url: URL to analyze visually.
        page_title: Caller-supplied title, used as a fallback when the
            live page's title can't be extracted.
        page_snippet: Accepted for interface compatibility; not consumed
            by the current pipeline.

    Returns:
        dict with tier4_score (float|None), tier4_status, tier4_reason,
        and — past the disabled gate — visual_heuristic, color_analysis,
        screenshot_cached.
    """
    # Gate: the whole tier is opt-in via environment configuration.
    if not ENABLE_VISUAL_TIER:
        return {
            "tier4_score": None,
            "tier4_status": "tier4_disabled",
            "tier4_reason": "ENABLE_VISUAL_TIER env var not set",
        }

    # Capture the page screenshot together with lightweight DOM metadata.
    capture = await take_screenshot_with_metadata(url)
    shot = capture["screenshot"]
    live_title = capture["page_title"] or page_title
    password_seen = capture["has_password_field"]

    # Heuristics need no pixels, so they run unconditionally.
    heur = analyze_visual_heuristic(
        url,
        page_title=live_title,
        has_password_field=password_seen,
    )

    if shot is None:
        # No image: fall back to the heuristic score alone.
        return {
            "tier4_score": heur["heuristic_visual_score"],
            "tier4_status": "screenshot_failed",
            "tier4_reason": capture["error"] or "unknown_screenshot_error",
            "visual_heuristic": heur,
            "color_analysis": None,
            "screenshot_cached": False,
        }

    # Pixel-level color analysis (works without a trained CNN).
    colors = analyze_visual_basic(shot, url)

    # Blend: 60% heuristic, 40% color (CNN isn't trained yet).
    blended = heur["heuristic_visual_score"] * 0.60 + colors["visual_risk"] * 0.40

    return {
        "tier4_score": round(min(blended, 1.0), 4),
        "tier4_status": "ok",
        "tier4_reason": "heuristic_and_color_analysis",
        "visual_heuristic": heur,
        "color_analysis": colors,
        "screenshot_cached": _get_cached_screenshot(url) is not None,
    }
|
vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|