prashanth135 committed
Commit bebe233 · verified · 1 Parent(s): 056ea3d

Upload 38 files

Dockerfile ADDED
@@ -0,0 +1,45 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Set environment variables
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PLAYWRIGHT_BROWSERS_PATH=/ms-playwright \
+     HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Create a non-root user for Hugging Face security
+ RUN useradd -m -u 1000 user
+ WORKDIR /home/user/app
+
+ # Install system dependencies required for Playwright and ML libs
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     wget curl libglib2.0-0 libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 \
+     libcups2 libdrm2 libdbus-1-3 libxcb1 libxkbcommon0 libx11-6 libxcomposite1 \
+     libxdamage1 libxext6 libxfixes3 libxrandr2 libgbm1 libpango-1.0-0 libcairo2 libasound2 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install PyTorch CPU
+ RUN pip install --no-cache-dir torch==2.2.2 torchvision==0.17.2 --index-url https://download.pytorch.org/whl/cpu
+
+ # Copy requirements and install
+ COPY --chown=user requirements.txt .
+ RUN grep -v "torch==" requirements.txt | grep -v "torchvision==" > req_filtered.txt && \
+     pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r req_filtered.txt
+
+ # Install Playwright browser
+ RUN playwright install chromium
+
+ # Copy project files
+ COPY --chown=user . .
+
+ # Create necessary directories and set permissions
+ RUN mkdir -p data logs bert_weights gnn cnn && \
+     chmod -R 777 data logs bert_weights
+
+ # Expose the Hugging Face default port
+ EXPOSE 7860
+
+ # Command to run on start (Port 7860 is required for HF)
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
__init__.py ADDED
@@ -0,0 +1 @@
+ # PhishGuard AI - CNN Module
admin.html ADDED
@@ -0,0 +1,250 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+   <meta charset="UTF-8">
+   <meta name="viewport" content="width=device-width, initial-scale=1.0">
+   <title>PhishGuard Admin</title>
+   <style>
+     @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
+     * { margin:0; padding:0; box-sizing:border-box; }
+     :root {
+       --bg: #0F0F14; --bg2: #1A1A24; --card: #22222E; --border: rgba(255,255,255,0.06);
+       --text: #EAEAF0; --text2: #8888A0; --accent: #534AB7;
+       --safe: #22C55E; --danger: #EF4444; --warn: #F59E0B;
+     }
+     body { font-family:'Inter',sans-serif; background:var(--bg); color:var(--text); min-height:100vh; }
+
+     /* Login */
+     .login-wrap { display:flex; align-items:center; justify-content:center; min-height:100vh; }
+     .login-box { background:var(--card); padding:36px; border-radius:16px; border:1px solid var(--border); width:340px; }
+     .login-box h2 { font-size:20px; margin-bottom:20px; text-align:center; }
+     .login-box input { width:100%; padding:10px 14px; background:var(--bg2); border:1px solid var(--border);
+       border-radius:8px; color:var(--text); font-size:14px; font-family:inherit; margin-bottom:12px; outline:none; }
+     .login-box input:focus { border-color:var(--accent); }
+     .login-box button { width:100%; padding:10px; background:linear-gradient(135deg,var(--accent),#6C5ECE);
+       border:none; border-radius:8px; color:#fff; font-size:14px; font-weight:600; cursor:pointer; font-family:inherit; }
+     .login-box button:hover { opacity:0.9; }
+     .login-error { color:var(--danger); font-size:12px; text-align:center; margin-top:8px; display:none; }
+
+     /* Dashboard */
+     .dashboard { display:none; max-width:1000px; margin:0 auto; padding:24px; }
+     .dash-header { display:flex; align-items:center; gap:12px; margin-bottom:24px; }
+     .dash-header h1 { font-size:22px; flex:1; }
+     .dash-header h1 span { color:var(--accent); }
+     .logout-btn { padding:6px 16px; background:var(--card); border:1px solid var(--border);
+       border-radius:8px; color:var(--text2); font-size:12px; cursor:pointer; font-family:inherit; }
+
+     /* Stats Cards */
+     .stats { display:grid; grid-template-columns:repeat(auto-fit,minmax(180px,1fr)); gap:12px; margin-bottom:24px; }
+     .stat-card { background:var(--card); border:1px solid var(--border); border-radius:12px; padding:16px; }
+     .stat-label { font-size:11px; color:var(--text2); text-transform:uppercase; letter-spacing:0.5px; }
+     .stat-value { font-size:28px; font-weight:700; margin-top:4px; }
+     .stat-sub { font-size:11px; color:var(--text2); margin-top:2px; }
+
+     /* Table */
+     .section-title { font-size:16px; font-weight:600; margin-bottom:12px; }
+     .table-wrap { background:var(--card); border:1px solid var(--border); border-radius:12px; overflow:hidden; margin-bottom:24px; }
+     table { width:100%; border-collapse:collapse; font-size:12px; }
+     th { background:var(--bg2); padding:10px 14px; text-align:left; font-weight:600; color:var(--text2);
+       text-transform:uppercase; letter-spacing:0.5px; font-size:10px; }
+     td { padding:10px 14px; border-top:1px solid var(--border); }
+     tr:hover td { background:rgba(255,255,255,0.02); }
+     .badge { display:inline-block; padding:2px 8px; border-radius:4px; font-size:10px; font-weight:600; }
+     .badge-phish { background:rgba(239,68,68,0.12); color:var(--danger); }
+     .badge-safe { background:rgba(34,197,94,0.12); color:var(--safe); }
+     .url-cell { max-width:300px; overflow:hidden; text-overflow:ellipsis; white-space:nowrap; font-family:'SF Mono',monospace; font-size:11px; color:var(--text2); }
+
+     /* Retrain Button */
+     .retrain-bar { display:flex; align-items:center; gap:12px; margin-bottom:24px; }
+     .retrain-btn { padding:10px 24px; background:linear-gradient(135deg,var(--accent),#6C5ECE);
+       border:none; border-radius:8px; color:#fff; font-size:13px; font-weight:600; cursor:pointer; font-family:inherit; }
+     .retrain-btn:hover { opacity:0.9; }
+     .retrain-btn:disabled { opacity:0.4; cursor:not-allowed; }
+     .retrain-status { font-size:12px; color:var(--text2); }
+
+     /* History */
+     .history-card { background:var(--card); border:1px solid var(--border); border-radius:12px; padding:16px; margin-bottom:8px;
+       display:flex; align-items:center; gap:16px; }
+     .hist-version { font-size:22px; font-weight:700; color:var(--accent); min-width:48px; text-align:center; }
+     .hist-detail { flex:1; }
+     .hist-detail div:first-child { font-size:13px; font-weight:500; }
+     .hist-detail div:last-child { font-size:11px; color:var(--text2); margin-top:2px; }
+     .hist-accuracy { font-size:14px; font-weight:600; }
+   </style>
+ </head>
+ <body>
+
+ <!-- Login Screen -->
+ <div class="login-wrap" id="loginScreen">
+   <div class="login-box">
+     <h2>🛡️ PhishGuard Admin</h2>
+     <input type="password" id="passInput" placeholder="Admin password" autocomplete="off">
+     <button onclick="attemptLogin()">Login</button>
+     <div class="login-error" id="loginError">Invalid password</div>
+   </div>
+ </div>
+
+ <!-- Dashboard (hidden until login) -->
+ <div class="dashboard" id="dashboard">
+   <div class="dash-header">
+     <h1>Phish<span>Guard</span> Admin</h1>
+     <button class="logout-btn" onclick="logout()">Logout</button>
+   </div>
+
+   <!-- Stats -->
+   <div class="stats">
+     <div class="stat-card">
+       <div class="stat-label">Total Feedback</div>
+       <div class="stat-value" id="sTotalFeedback">—</div>
+     </div>
+     <div class="stat-card">
+       <div class="stat-label">Phishing Reports</div>
+       <div class="stat-value" id="sPhishing" style="color:var(--danger)">—</div>
+     </div>
+     <div class="stat-card">
+       <div class="stat-label">Safe Reports</div>
+       <div class="stat-value" id="sSafe" style="color:var(--safe)">—</div>
+     </div>
+     <div class="stat-card">
+       <div class="stat-label">Model Version</div>
+       <div class="stat-value" id="sVersion" style="color:var(--accent)">—</div>
+       <div class="stat-sub" id="sLastRetrain">Never retrained</div>
+     </div>
+     <div class="stat-card">
+       <div class="stat-label">Unprocessed</div>
+       <div class="stat-value" id="sUnprocessed" style="color:var(--warn)">—</div>
+       <div class="stat-sub">of 50 needed</div>
+     </div>
+   </div>
+
+   <!-- Manual Retrain -->
+   <div class="retrain-bar">
+     <button class="retrain-btn" id="retrainBtn" onclick="triggerRetrain()">🔄 Trigger Retraining</button>
+     <div class="retrain-status" id="retrainStatus"></div>
+   </div>
+
+   <!-- Recent Feedback -->
+   <div class="section-title">Recent Feedback</div>
+   <div class="table-wrap">
+     <table>
+       <thead>
+         <tr><th>URL</th><th>Label</th><th>Source</th><th>Prediction</th><th>Time</th></tr>
+       </thead>
+       <tbody id="feedbackTable"><tr><td colspan="5" style="text-align:center;color:var(--text2)">Loading...</td></tr></tbody>
+     </table>
+   </div>
+
+   <!-- Retrain History -->
+   <div class="section-title">Retrain History</div>
+   <div id="historyList"><div style="color:var(--text2);font-size:12px">Loading...</div></div>
+ </div>
+
+ <script>
+ const BASE = window.location.origin;
+ let authToken = "";
+
+ // Login
+ function attemptLogin() {
+   const pass = document.getElementById("passInput").value;
+   fetch(`${BASE}/admin/login`, {
+     method: "POST",
+     headers: {"Content-Type":"application/json"},
+     body: JSON.stringify({password: pass})
+   })
+   .then(r => r.json())
+   .then(data => {
+     if (data.success) {
+       authToken = data.token;
+       document.getElementById("loginScreen").style.display = "none";
+       document.getElementById("dashboard").style.display = "block";
+       loadDashboard();
+     } else {
+       document.getElementById("loginError").style.display = "block";
+     }
+   })
+   .catch(() => { document.getElementById("loginError").style.display = "block"; });
+ }
+ document.getElementById("passInput").addEventListener("keyup", e => { if(e.key==="Enter") attemptLogin(); });
+
+ function logout() {
+   authToken = "";
+   document.getElementById("loginScreen").style.display = "flex";
+   document.getElementById("dashboard").style.display = "none";
+ }
+
+ // Dashboard data
+ function loadDashboard() {
+   // Stats
+   fetch(`${BASE}/admin/data?token=${authToken}`).then(r=>r.json()).then(data => {
+     if (data.error) { logout(); return; }
+     const s = data.stats;
+     document.getElementById("sTotalFeedback").textContent = s.total_feedback;
+     document.getElementById("sPhishing").textContent = s.phishing_corrections;
+     document.getElementById("sSafe").textContent = s.safe_corrections;
+     document.getElementById("sVersion").textContent = "v" + s.model_version;
+     document.getElementById("sUnprocessed").textContent = s.unprocessed_count;
+     document.getElementById("sLastRetrain").textContent = s.last_retrain
+       ? "Last: " + new Date(s.last_retrain).toLocaleString() : "Never retrained";
+
+     // Feedback table
+     const rows = data.recent.map(e => `
+       <tr>
+         <td class="url-cell" title="${esc(e.url)}">${esc(e.url)}</td>
+         <td><span class="badge ${e.label==='phishing'?'badge-phish':'badge-safe'}">${esc(e.label)}</span></td>
+         <td>${esc(e.source||'—')}</td>
+         <td>${e.original_prediction!=null ? (e.original_prediction*100).toFixed(0)+'%' : '—'}</td>
+         <td style="font-size:11px;color:var(--text2)">${e.timestamp ? new Date(e.timestamp).toLocaleString() : '—'}</td>
+       </tr>
+     `).join("");
+     document.getElementById("feedbackTable").innerHTML = rows || '<tr><td colspan="5" style="text-align:center;color:var(--text2)">No feedback yet</td></tr>';
+
+     // History
+     const hist = (s.retrain_history || []).reverse();
+     if (hist.length === 0) {
+       document.getElementById("historyList").innerHTML = '<div style="color:var(--text2);font-size:12px">No retraining history</div>';
+     } else {
+       document.getElementById("historyList").innerHTML = hist.map(h => `
+         <div class="history-card">
+           <div class="hist-version">v${h.version}</div>
+           <div class="hist-detail">
+             <div>Trained on ${h.samples} samples</div>
+             <div>${new Date(h.timestamp).toLocaleString()}</div>
+           </div>
+           <div class="hist-accuracy" style="color:${h.accuracy>=0.8?'var(--safe)':h.accuracy>=0.6?'var(--warn)':'var(--danger)'}">
+             ${(h.accuracy*100).toFixed(1)}%
+           </div>
+         </div>
+       `).join("");
+     }
+   });
+ }
+
+ // Retrain
+ function triggerRetrain() {
+   const btn = document.getElementById("retrainBtn");
+   btn.disabled = true;
+   btn.textContent = "⏳ Retraining...";
+   document.getElementById("retrainStatus").textContent = "Training in progress...";
+
+   fetch(`${BASE}/admin/retrain?token=${authToken}`, {method:"POST"})
+   .then(r=>r.json())
+   .then(data => {
+     document.getElementById("retrainStatus").textContent = data.message || "Done";
+     btn.disabled = false;
+     btn.textContent = "🔄 Trigger Retraining";
+     setTimeout(loadDashboard, 2000);
+   })
+   .catch(e => {
+     document.getElementById("retrainStatus").textContent = "Error: " + e.message;
+     btn.disabled = false;
+     btn.textContent = "🔄 Trigger Retraining";
+   });
+ }
+
+ function esc(s) { const d=document.createElement('div'); d.textContent=String(s||''); return d.innerHTML; }
+
+ // Auto-refresh every 30s
+ setInterval(() => { if(authToken) loadDashboard(); }, 30000);
+ </script>
+ </body>
+ </html>
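
Note: admin.html assumes a small same-origin JSON API — POST /admin/login returning {success, token}, GET /admin/data?token=... returning {stats, recent} (or {error} on a bad token), and POST /admin/retrain?token=.... main.py is not part of this excerpt, so the sketch below is illustrative only: the field names come from the JavaScript above, while the in-memory token store and the ADMIN_PASSWORD environment variable are assumptions, not the repo's actual implementation.

# Hypothetical sketch of the endpoints admin.html calls; NOT the repo's main.py.
import os
import secrets
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
_tokens = set()  # assumption: simple in-memory session tokens

class LoginBody(BaseModel):
    password: str

@app.post("/admin/login")
def admin_login(body: LoginBody):
    # ADMIN_PASSWORD env var is an assumption for this sketch
    if body.password == os.environ.get("ADMIN_PASSWORD", ""):
        token = secrets.token_hex(16)
        _tokens.add(token)
        return {"success": True, "token": token}
    return {"success": False}

@app.get("/admin/data")
def admin_data(token: str):
    if token not in _tokens:
        return {"error": "unauthorized"}  # admin.html logs out on "error"
    # Shape mirrors what loadDashboard() reads; values are placeholders.
    return {
        "stats": {
            "total_feedback": 0, "phishing_corrections": 0,
            "safe_corrections": 0, "model_version": 0,
            "unprocessed_count": 0, "last_retrain": None,
            "retrain_history": [],
        },
        "recent": [],
    }

@app.post("/admin/retrain")
def admin_retrain(token: str):
    if token not in _tokens:
        return {"error": "unauthorized"}
    return {"message": "Retraining started"}  # real backend would train here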
background.js ADDED
@@ -0,0 +1,604 @@
+ // ============================================================
+ // PhishGuard AI - background.js
+ // MV3 Service Worker with feedback, retraining triggers, and
+ // model version polling.
+ //
+ // State (chrome.storage.local):
+ //   phishguard_feedback_queue: FeedbackRecord[] (max 500, FIFO)
+ //   scan_count: int (resets at 50)
+ //   feedback_count: int (labeled samples since last retrain)
+ //   last_retrain_ts: ISO8601
+ //   model_version: int
+ //   session_id: UUIDv4
+ //
+ // Triggers:
+ //   1. scan_count >= 50 AND feedback_count >= 10
+ //   2. chrome.alarms "retrain_alarm" (24h) AND feedback_count >= 10
+ // ============================================================
+
+ // ── Backend URL ──────────────────────────────────────────────────────
+ const BACKEND_URL = "https://phishguard-api-z2wj.onrender.com";
+ const ANALYZE_URL = `${BACKEND_URL}/analyze`;
+ const RETRAIN_URL = `${BACKEND_URL}/retrain`;
+ const MODEL_VERSION_URL = `${BACKEND_URL}/model_version`;
+
+ // ── Constants ────────────────────────────────────────────────────────
+ const CACHE_TTL_MS = 30 * 60 * 1000;
+ const MAX_QUEUE_SIZE = 500;
+ const RETRAIN_URL_THRESHOLD = 50;
+ const MIN_LABELED_SAMPLES = 10;
+
+ // ── In-memory caches ─────────────────────────────────────────────────
+ const urlCache = new Map();
+ const tabResultCache = new Map();
+ const pageSignals = new Map();
+
+ // ── TIER 1: Whitelist (O(1) Set lookup) ──────────────────────────────
+ const WHITELIST = new Set([
+   "google.com","youtube.com","facebook.com","amazon.com","wikipedia.org",
+   "twitter.com","instagram.com","linkedin.com","microsoft.com","apple.com",
+   "github.com","stackoverflow.com","reddit.com","netflix.com","paypal.com",
+   "bankofamerica.com","chase.com","wellsfargo.com","yahoo.com","bing.com",
+   "outlook.com","office.com","live.com","adobe.com","dropbox.com",
+   "zoom.us","slack.com","spotify.com","twitch.tv","ebay.com",
+   "walmart.com","target.com","bestbuy.com","airbnb.com",
+   "x.com","tiktok.com","pinterest.com","quora.com","medium.com"
+ ]);
+
+ function getRootDomain(url) {
+   try {
+     const host = new URL(url).hostname.replace(/^www\./, "");
+     const parts = host.split(".");
+     return parts.slice(-2).join(".");
+   } catch { return null; }
+ }
+
+ // ── TIER 2: Local heuristic scoring ──────────────────────────────────
+ function heuristicScore(url) {
+   let score = 0;
+   const signals = [];
+   const u = url.toLowerCase();
+
+   // IP as hostname (25 pts)
+   if (/https?:\/\/\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/.test(url)) {
+     score += 25; signals.push("IP as hostname");
+   }
+
+   // Suspicious TLD (20 pts)
+   const badTLDs = [".xyz",".tk",".ml",".ga",".cf",".gq",".pw",".top",".click"];
+   for (const tld of badTLDs) {
+     if (u.includes(tld)) { score += 20; signals.push(`Suspicious TLD (${tld})`); break; }
+   }
+
+   // Phishing keywords (15 pts)
+   const keywords = ["login","verify","secure","update","account","banking",
+     "signin","reset","confirm","suspend","webscr","cmd","payment","alert"];
+   const kwHits = keywords.filter(kw => u.includes(kw));
+   if (kwHits.length > 0) { score += 15; signals.push(`Keywords: ${kwHits.join(", ")}`); }
+
+   // Brand spoofing (15 pts)
+   const brands = ["paypal","google","apple","microsoft","amazon","netflix",
+     "facebook","instagram","chase","wellsfargo","bankofamerica"];
+   try {
+     const domain = getRootDomain(url);
+     for (const brand of brands) {
+       if (u.includes(brand) && domain && !domain.startsWith(brand)) {
+         score += 15; signals.push(`Brand spoofing: ${brand}`); break;
+       }
+     }
+   } catch {}
+
+   // Excessive subdomains (10 pts)
+   try {
+     const host = new URL(url).hostname;
+     const subCount = host.split(".").length - 2;
+     if (subCount >= 3) { score += 10; signals.push(`${subCount} subdomains`); }
+   } catch {}
+
+   // URL length (5 pts)
+   if (url.length > 100) { score += 5; signals.push(`Long URL (${url.length} chars)`); }
+
+   // Hyphens (5 pts)
+   try {
+     const host = new URL(url).hostname;
+     const hyphens = (host.match(/-/g) || []).length;
+     if (hyphens >= 3) { score += 5; signals.push(`${hyphens} hyphens in domain`); }
+   } catch {}
+
+   // Non-standard port (5 pts)
+   try {
+     const port = new URL(url).port;
+     if (port && port !== "80" && port !== "443") {
+       score += 5; signals.push(`Non-standard port :${port}`);
+     }
+   } catch {}
+
+   return { score: Math.min(score, 100), signals };
+ }
+
+ // ── URL Cache ────────────────────────────────────────────────────────
+ function getCached(url) {
+   const entry = urlCache.get(url);
+   if (!entry) return null;
+   if (Date.now() - entry.ts > CACHE_TTL_MS) { urlCache.delete(url); return null; }
+   return entry.result;
+ }
+
+ function setCache(url, result) {
+   urlCache.set(url, { result, ts: Date.now() });
+   if (urlCache.size > 500) {
+     const firstKey = urlCache.keys().next().value;
+     urlCache.delete(firstKey);
+   }
+ }
+
+ // ── Badge ────────────────────────────────────────────────────────────
+ function setBadge(tabId, status, text) {
+   const colors = {
+     safe: "#22C55E", blocked: "#EF4444", warn: "#F59E0B",
+     loading: "#534AB7", none: "#888888"
+   };
+   chrome.action.setBadgeBackgroundColor({ color: colors[status] || colors.none, tabId });
+   chrome.action.setBadgeText({ text: text || "", tabId });
+ }
+
+ // ── Backend fetch with retry ─────────────────────────────────────────
+ async function fetchBackend(url, payload, retryCount = 1) {
+   try {
+     const controller = new AbortController();
+     const timeout = setTimeout(() => controller.abort(), 15000);
+     const response = await fetch(url, {
+       method: "POST",
+       headers: { "Content-Type": "application/json" },
+       body: JSON.stringify(payload),
+       signal: controller.signal,
+     });
+     clearTimeout(timeout);
+     if (!response.ok) throw new Error(`Server ${response.status}`);
+     return await response.json();
+   } catch (err) {
+     if (retryCount > 0) {
+       await new Promise(r => setTimeout(r, 2000));
+       return fetchBackend(url, payload, retryCount - 1);
+     }
+     throw err;
+   }
+ }
+
+ // ── SHA256 hash ──────────────────────────────────────────────────────
+ async function sha256(text) {
+   const encoded = new TextEncoder().encode(text);
+   const hash = await crypto.subtle.digest("SHA-256", encoded);
+   return Array.from(new Uint8Array(hash)).map(b => b.toString(16).padStart(2, "0")).join("");
+ }
+
+ // ── Storage helpers ──────────────────────────────────────────────────
+ async function getStorage(keys) {
+   return new Promise(resolve => chrome.storage.local.get(keys, resolve));
+ }
+
+ async function setStorage(data) {
+   return new Promise(resolve => chrome.storage.local.set(data, resolve));
+ }
+
+ async function getQueue() {
+   const data = await getStorage(["phishguard_feedback_queue"]);
+   return data.phishguard_feedback_queue || [];
+ }
+
+ async function setQueue(queue) {
+   // FIFO eviction
+   if (queue.length > MAX_QUEUE_SIZE) {
+     queue = queue.slice(queue.length - MAX_QUEUE_SIZE);
+   }
+   await setStorage({ phishguard_feedback_queue: queue });
+ }
+
+ // ── ON INSTALL ───────────────────────────────────────────────────────
+ chrome.runtime.onInstalled.addListener(async () => {
+   const sessionId = crypto.randomUUID();
+   await setStorage({
+     session_id: sessionId,
+     scan_count: 0,
+     feedback_count: 0,
+     last_retrain_ts: null,
+     model_version: 0,
+     phishguard_feedback_queue: [],
+   });
+
+   // 24-hour retraining alarm
+   chrome.alarms.create("retrain_alarm", { periodInMinutes: 1440 });
+   // 30-minute model polling alarm
+   chrome.alarms.create("model_poll_alarm", { periodInMinutes: 30 });
+
+   console.log("[PhishGuard] Installed. Session:", sessionId);
+ });
+
+ // ── ALARM HANDLERS ───────────────────────────────────────────────────
+ chrome.alarms.onAlarm.addListener(async (alarm) => {
+   if (alarm.name === "retrain_alarm") {
+     console.log("[PhishGuard] Retrain alarm fired");
+     await checkRetrain("timer");
+   }
+   if (alarm.name === "model_poll_alarm") {
+     await pollModelVersion();
+   }
+ });
+
+ // ── MAIN URL LISTENER ────────────────────────────────────────────────
+ chrome.webNavigation.onCompleted.addListener(async (details) => {
+   if (details.frameId !== 0) return;
+   const url = details.url;
+   if (!url.startsWith("http")) return;
+
+   const tabId = details.tabId;
+   const domain = getRootDomain(url);
+   if (!domain) return;
+
+   setBadge(tabId, "loading", "…");
+
+   // TIER 1: Whitelist
+   if (WHITELIST.has(domain)) {
+     const result = {
+       url, status: "safe", tier: 1, method: "whitelist",
+       confidence: 0, heuristic_score: 0, signals: []
+     };
+     await setStorage({ lastResult: result });
+     tabResultCache.set(tabId, result);
+     setBadge(tabId, "safe", "✓");
+     return;
+   }
+
+   // Cache check
+   const cached = getCached(url);
+   if (cached) {
+     await setStorage({ lastResult: cached });
+     tabResultCache.set(tabId, cached);
+     setBadge(tabId, cached.status, cached.status === "blocked" ? "!" : "✓");
+     if (cached.status === "blocked") blockPage(tabId, url, cached);
+     return;
+   }
+
+   // TIER 2: Heuristic
+   const hResult = heuristicScore(url);
+
+   if (hResult.score >= 80) {
+     const result = {
+       url, status: "blocked", tier: 2, method: "heuristic",
+       confidence: hResult.score / 100, heuristic_score: hResult.score,
+       signals: hResult.signals, is_phishing: true
+     };
+     setCache(url, result);
+     await setStorage({ lastResult: result });
+     tabResultCache.set(tabId, result);
+     setBadge(tabId, "blocked", "!");
+     blockPage(tabId, url, result);
+     await storeFeedbackRecord(url, result);
+     await incrementScanCount();
+     return;
+   }
+
+   // TIER 3+4: Send to backend
+   const signals = pageSignals.get(tabId) || {};
+   try {
+     const apiResult = await fetchBackend(ANALYZE_URL, {
+       url,
+       heuristic_score: hResult.score,
+       page_title: signals.title || "",
+       page_snippet: signals.snippet || "",
+     });
+
+     const finalResult = {
+       url,
+       status: apiResult.is_phishing ? "blocked" : "safe",
+       tier: apiResult.tier || 3,
+       method: apiResult.method || "ensemble",
+       confidence: apiResult.confidence || 0,
+       heuristic_score: apiResult.heuristic_score || hResult.score,
+       signals: apiResult.signals || hResult.signals,
+       is_phishing: apiResult.is_phishing,
+       details: apiResult.details || {},
+     };
+
+     setCache(url, finalResult);
+     await setStorage({ lastResult: finalResult });
+     tabResultCache.set(tabId, finalResult);
+
+     if (finalResult.status === "blocked") {
+       setBadge(tabId, "blocked", "!");
+       blockPage(tabId, url, finalResult);
+     } else if (finalResult.confidence >= 0.4) {
+       setBadge(tabId, "warn", "?");
+     } else {
+       setBadge(tabId, "safe", "✓");
+     }
+
+     await storeFeedbackRecord(url, finalResult);
+
+   } catch (err) {
+     console.log("[PhishGuard] Backend unreachable:", err.message);
+     const fallback = {
+       url,
+       status: hResult.score >= 50 ? "blocked" : "safe",
+       tier: 2,
+       method: "heuristic-fallback",
+       confidence: hResult.score / 100,
+       heuristic_score: hResult.score,
+       signals: hResult.signals,
+       is_phishing: hResult.score >= 50,
+       details: { backend_error: err.message },
+     };
+     setCache(url, fallback);
+     await setStorage({ lastResult: fallback });
+     tabResultCache.set(tabId, fallback);
+
+     if (hResult.score >= 50) {
+       setBadge(tabId, "blocked", "!");
+       blockPage(tabId, url, fallback);
+     } else if (hResult.score >= 30) {
+       setBadge(tabId, "warn", "?");
+     } else {
+       setBadge(tabId, "none", "");
+     }
+
+     await storeFeedbackRecord(url, fallback);
+   }
+
+   await incrementScanCount();
+   await checkRetrain("count");
+   pageSignals.delete(tabId);
+
+ }, { url: [{ schemes: ["http", "https"] }] });
+
+ // ── Feedback Record Storage ──────────────────────────────────────────
+ async function storeFeedbackRecord(url, result) {
+   const urlHash = await sha256(url);
+   const record = {
+     url,
+     verdict: result.is_phishing ? "phishing" : "safe",
+     confidence: result.confidence || 0,
+     tier_used: result.tier || 0,
+     heuristic_score: result.heuristic_score || 0,
+     signals: result.signals || [],
+     user_feedback: null,
+     timestamp: new Date().toISOString(),
+     feedback_ts: null,
+     url_hash: urlHash,
+     session_id: (await getStorage(["session_id"])).session_id || "",
+   };
+
+   const queue = await getQueue();
+   queue.push(record);
+   await setQueue(queue);
+ }
+
+ async function incrementScanCount() {
+   const data = await getStorage(["scan_count"]);
+   await setStorage({ scan_count: (data.scan_count || 0) + 1 });
+ }
+
+ // ── Block Page ───────────────────────────────────────────────────────
+ function blockPage(tabId, url, result) {
+   chrome.storage.local.set({ lastResult: { ...result, status: "blocked" } });
+   tabResultCache.set(tabId, result);
+   const score = Math.round((result.confidence || 0) * 100);
+   chrome.tabs.update(tabId, {
+     url: chrome.runtime.getURL("popup.html") +
+       "?blocked=1&url=" + encodeURIComponent(url) +
+       "&score=" + score +
+       "&method=" + encodeURIComponent(result.method || "")
+   });
+ }
+
+ // ── Retrain Check ────────────────────────────────────────────────────
+ async function checkRetrain(trigger = "count") {
+   const queue = await getQueue();
+   const labeled = queue.filter(r => r.user_feedback !== null);
+
+   if (labeled.length < MIN_LABELED_SAMPLES) {
+     console.log(`[PhishGuard] Not enough labeled samples (${labeled.length}/${MIN_LABELED_SAMPLES})`);
+     return;
+   }
+
+   const data = await getStorage(["scan_count"]);
+   const scanCount = data.scan_count || 0;
+
+   if (trigger === "timer" || scanCount >= RETRAIN_URL_THRESHOLD) {
+     console.log(`[PhishGuard] Triggering retrain: trigger=${trigger}, labeled=${labeled.length}, scans=${scanCount}`);
+     await sendRetrainRequest(labeled, trigger);
+   }
+ }
+
+ async function sendRetrainRequest(samples, trigger) {
+   const data = await getStorage(["session_id"]);
+   try {
+     const result = await fetchBackend(RETRAIN_URL, {
+       samples,
+       trigger,
+       session_id: data.session_id || "",
+       extension_version: "3.0",
+     });
+
+     if (result.status === "success") {
+       // Reset counters
+       await setStorage({
+         scan_count: 0,
+         feedback_count: 0,
+         last_retrain_ts: new Date().toISOString(),
+       });
+
+       // Remove sent records from queue
+       const queue = await getQueue();
+       const sentHashes = new Set(samples.map(s => s.url_hash));
+       const remaining = queue.filter(r => !sentHashes.has(r.url_hash));
+       await setQueue(remaining);
+
+       // Show notification
+       showRetrainNotification(result.accuracy_delta || {});
+
+       console.log("[PhishGuard] Retrain success:", result);
+     }
+   } catch (err) {
+     console.error("[PhishGuard] Retrain request failed:", err.message);
+   }
+ }
+
+ function showRetrainNotification(delta) {
+   const bertDelta = delta.bert ? `BERT: ${(delta.bert * 100).toFixed(1)}%` : "";
+   const gnnDelta = delta.gnn ? `GNN: ${(delta.gnn * 100).toFixed(1)}%` : "";
+   const parts = [bertDelta, gnnDelta].filter(Boolean).join(", ");
+
+   chrome.notifications.create("retrain_complete", {
+     type: "basic",
+     iconUrl: "icons/icon48.png",
+     title: "PhishGuard AI Updated",
+     message: parts ? `Models improved! ${parts} accuracy from your feedback` :
+       "Models updated with your feedback",
+   });
+ }
+
+ // ── Model Version Polling ────────────────────────────────────────────
+ async function pollModelVersion() {
+   try {
+     const controller = new AbortController();
+     const timeout = setTimeout(() => controller.abort(), 10000);
+     const resp = await fetch(MODEL_VERSION_URL, { signal: controller.signal });
+     clearTimeout(timeout);
+
+     if (!resp.ok) return;
+     const info = await resp.json();
+
+     const stored = await getStorage(["model_version"]);
+     if (info.version > (stored.model_version || 0)) {
+       await setStorage({ model_version: info.version });
+       // Clear URL cache (stale results)
+       urlCache.clear();
+
+       chrome.notifications.create("model_updated", {
+         type: "basic",
+         iconUrl: "icons/icon48.png",
+         title: "PhishGuard Models Updated",
+         message: `Model v${info.version} is now active`,
+       });
+     }
+   } catch (err) {
+     // Silently fail — model polling is best-effort
+   }
+ }
+
+ // ── Message Handler ──────────────────────────────────────────────────
+ chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
+   // Page signals from content.js
+   if (msg.type === "page_signals") {
+     if (sender.tab) {
+       pageSignals.set(sender.tab.id, {
+         title: msg.title || "",
+         snippet: msg.snippet || "",
+         signals: msg.signals || [],
+       });
+     }
+   }
+
+   // Submit feedback from popup.js / content.js
+   if (msg.type === "submit_feedback") {
+     (async () => {
+       const queue = await getQueue();
+       const idx = queue.findIndex(r => r.url_hash === msg.url_hash);
+       if (idx >= 0) {
+         queue[idx].user_feedback = msg.feedback; // "correct" or "incorrect"
+         queue[idx].feedback_ts = new Date().toISOString();
+         await setQueue(queue);
+
+         // Increment feedback count
+         const data = await getStorage(["feedback_count"]);
+         await setStorage({ feedback_count: (data.feedback_count || 0) + 1 });
+
+         // Check if we should trigger retraining
+         await checkRetrain("count");
+
+         sendResponse({ success: true });
+       } else {
+         sendResponse({ success: false, error: "Record not found" });
+       }
+     })();
+     return true; // async response
+   }
+
+   // Get status for popup
+   if (msg.type === "get_status") {
+     (async () => {
+       const data = await getStorage([
+         "scan_count", "feedback_count", "last_retrain_ts",
+         "model_version", "session_id"
+       ]);
+       const queue = await getQueue();
+       const labeled = queue.filter(r => r.user_feedback !== null).length;
+
+       const lastRetrain = data.last_retrain_ts ? new Date(data.last_retrain_ts) : null;
+       const now = Date.now();
+       const nextTimerMs = lastRetrain
+         ? Math.max(0, (24 * 60 * 60 * 1000) - (now - lastRetrain.getTime()))
+         : 24 * 60 * 60 * 1000;
+
+       sendResponse({
+         scan_count: data.scan_count || 0,
+         feedback_count: data.feedback_count || 0,
+         labeled_count: labeled,
+         last_retrain_ts: data.last_retrain_ts,
+         model_version: data.model_version || 0,
+         next_retrain_urls_remaining: Math.max(0, RETRAIN_URL_THRESHOLD - (data.scan_count || 0)),
+         next_retrain_time_remaining_ms: nextTimerMs,
+         min_labeled_needed: Math.max(0, MIN_LABELED_SAMPLES - labeled),
+       });
+     })();
+     return true;
+   }
+
+   // Per-tab result cache query from popup
+   if (msg.type === "get_tab_result") {
+     const result = tabResultCache.get(msg.tabId);
+     sendResponse({ result: result || null });
+     return false;
+   }
+
+   // User override (Proceed Anyway)
+   if (msg.type === "whitelist_url") {
+     const override = {
+       url: msg.url, status: "safe", tier: 0,
+       method: "user-override", confidence: 0
+     };
+     setCache(msg.url, override);
+     chrome.storage.local.set({ lastResult: override });
+     sendResponse({ success: true });
+   }
+
+   // Gmail scanner bridge
+   if (msg.action === "analyzeEmail") {
+     const emailURL = ANALYZE_URL.replace(/\/analyze\/?$/, "/analyze/email");
+     fetch(emailURL, {
+       method: "POST",
+       headers: { "Content-Type": "application/json" },
+       body: JSON.stringify(msg.data),
+     })
+     .then(r => r.ok ? r.json() : Promise.reject(new Error(`${r.status}`)))
+     .then(data => sendResponse(data))
+     .catch(err => sendResponse({
+       status: "error",
+       analysis: { isPhishing: false, probability: 0, reason: "Backend unreachable" }
+     }));
+     return true;
+   }
+ });
+
+ // ── Tab cleanup ──────────────────────────────────────────────────────
+ chrome.tabs.onRemoved.addListener(tabId => {
+   pageSignals.delete(tabId);
+   tabResultCache.delete(tabId);
+ });
+
+ chrome.tabs.onUpdated.addListener((tabId, changeInfo) => {
+   if (changeInfo.url) {
+     tabResultCache.delete(tabId);
+     setBadge(tabId, "none", "");
+   }
+ });
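
For reference, storeFeedbackRecord() above fixes the shape of the records that /retrain later receives in its samples array. A minimal Python sketch of that schema on the backend side follows; the dataclass and helper are illustrative, and only the field names and the MIN_LABELED_SAMPLES = 10 precondition come from the code above.

# Illustrative mirror of the extension's FeedbackRecord; not the repo's code.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FeedbackRecord:
    url: str
    verdict: str                        # "phishing" | "safe"
    confidence: float                   # ensemble confidence in [0, 1]
    tier_used: int                      # which tier produced the verdict
    heuristic_score: int                # 0-100 Tier-2 score
    signals: List[str] = field(default_factory=list)
    user_feedback: Optional[str] = None # "correct" | "incorrect" | None
    timestamp: str = ""                 # ISO-8601 scan time
    feedback_ts: Optional[str] = None   # ISO-8601 label time
    url_hash: str = ""                  # SHA-256 of the URL
    session_id: str = ""

def labeled(records: List[FeedbackRecord]) -> List[FeedbackRecord]:
    # Mirrors the extension's retrain precondition: only records the user
    # actually labeled count toward MIN_LABELED_SAMPLES (10).
    return [r for r in records if r.user_feedback is not None]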
bert_analyzer.py ADDED
@@ -0,0 +1,375 @@
+ # ============================================================
+ # PhishGuard AI - bert_analyzer.py
+ # Tier 3a: BERT NLP Phishing Classifier
+ #
+ # Model: ealvaradob/bert-finetuned-phishing (HuggingFace Hub)
+ # Tokenization: split on [-./=?&_~%@:] to preserve homoglyphs
+ # Input: "URL: {tokenized_url}. Title: {title}. Content: {snippet}"
+ # Output: P_bert ∈ [0,1]
+ # Supports: load, predict, fine-tune, incremental_update, save/load
+ # ============================================================
+
+ from __future__ import annotations
+
+ import re
+ import math
+ import logging
+ import threading
+ from pathlib import Path
+ from typing import List, Tuple, Optional, Dict
+
+ logger = logging.getLogger("phishguard.bert")
+
+ # ── Model state ──────────────────────────────────────────────────────
+ _classifier = None
+ _tokenizer = None
+ _model = None
+ _use_bert: bool = False
+ _bert_load_attempted: bool = False
+ _bert_lock = threading.Lock()
+
+ # Check if transformers library is installed
+ _transformers_available: bool = False
+ try:
+     import transformers as _tf_module
+     _transformers_available = True
+     logger.info("transformers library found — BERT will lazy-load on first call")
+ except ImportError:
+     logger.info("transformers not installed — using keyword NLP fallback")
+
+
+ # ── Phishing pattern databases (for keyword fallback) ────────────────
+ PHISHING_TERMS = [
+     "verify your account", "suspended", "click here immediately",
+     "unusual activity", "confirm your identity", "limited time",
+     "your password has been", "unauthorized access", "act now",
+     "secure your account", "login credentials", "reset password immediately",
+     "your account will be", "verify your identity", "we noticed suspicious",
+ ]
+
+ PHISHING_KEYWORDS = [
+     "login", "secure", "verify", "account", "update", "confirm",
+     "banking", "paypal", "signin", "password", "suspend", "alert",
+     "restore", "unusual", "limited", "expire", "urgent", "immediately",
+ ]
+
+ BRAND_NAMES = [
+     "paypal", "google", "apple", "microsoft", "amazon", "netflix",
+     "facebook", "instagram", "twitter", "linkedin", "chase", "wells",
+     "bankofamerica", "citibank", "usps", "fedex", "ebay",
+ ]
+
+
+ class BERTPhishingClassifier:
+     """
+     BERT-based phishing text classifier.
+     Wraps HuggingFace model with URL-aware tokenization.
+     """
+
+     DEFAULT_MODEL = "ealvaradob/bert-finetuned-phishing"
+     FALLBACK_MODEL = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
+
+     def __init__(self, model_name: Optional[str] = None) -> None:
+         self.model_name: str = model_name or self.DEFAULT_MODEL
+         self._pipeline = None
+         self._tokenizer = None
+         self._model = None
+         self._loaded: bool = False
+         self._lock = threading.Lock()
+         self._re_url_split = re.compile(r"[-./=?&_~%@:]+")
+
+     def load_model(self) -> None:
+         """Load BERT model from HuggingFace Hub with cache fallback."""
+         if self._loaded:
+             return
+         with self._lock:
+             if self._loaded:
+                 return
+             if not _transformers_available:
+                 logger.warning("transformers not available, BERT disabled")
+                 return
+             try:
+                 from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+
+                 # Try primary model, fall back to smaller model
+                 for model_id in [self.model_name, self.FALLBACK_MODEL]:
+                     try:
+                         self._pipeline = pipeline(
+                             "text-classification",
+                             model=model_id,
+                             truncation=True,
+                             max_length=512,
+                             device=-1,
+                         )
+                         self._tokenizer = AutoTokenizer.from_pretrained(model_id)
+                         self._model = AutoModelForSequenceClassification.from_pretrained(model_id)
+                         self.model_name = model_id
+                         self._loaded = True
+                         logger.info(f"BERT model loaded: {model_id}")
+                         return
+                     except Exception as e:
+                         logger.warning(f"Failed to load {model_id}: {e}")
+                         continue
+
+                 logger.error("All BERT model candidates failed")
+
+             except Exception as e:
+                 logger.error(f"BERT initialization failed: {e}")
+
+     def tokenize_url(self, url: str) -> str:
+         """
+         Split URL on [-./=?&_~%@:] to preserve homoglyphs.
+         Example: "paypa1-l0gin.xyz/verify" → "paypa1 l0gin xyz verify"
+         """
+         text = url.replace("https://", "").replace("http://", "")
+         tokens = self._re_url_split.split(text)
+         return " ".join(t for t in tokens if t)
+
+     def predict(self, url: str, title: str = "", snippet: str = "") -> float:
+         """
+         Predict phishing probability for a URL + page context.
+         Returns P_bert ∈ [0,1].
+         """
+         self.load_model()
+
+         if self._loaded and self._pipeline is not None:
+             return self._predict_bert(url, title, snippet)
+         return self._predict_keyword(url, title, snippet)
+
+     def _predict_bert(self, url: str, title: str, snippet: str) -> float:
+         """BERT model prediction path."""
+         url_text = self.tokenize_url(url)
+         combined = f"URL: {url_text}. Title: {title}. Content: {snippet[:300]}"
+
+         result = self._pipeline(combined[:512])[0]
+         label = result["label"].upper()
+         confidence = result["score"]
+
+         # Map label to phishing probability
+         if any(kw in label for kw in ["SPAM", "PHISH", "MALICIOUS", "LABEL_1", "1"]):
+             raw_prob = confidence
+         else:
+             raw_prob = 1.0 - confidence
+
+         # Boost with keyword signals
+         text_lower = combined.lower()
+         phrase_hits = sum(1 for p in PHISHING_TERMS if p in text_lower)
+         adjusted = min(raw_prob + (phrase_hits * 0.05), 1.0)
+
+         return round(adjusted, 4)
+
+     def _predict_keyword(self, url: str, title: str, snippet: str) -> float:
+         """Keyword-based fallback when BERT is unavailable."""
+         combined = f"{url} {title} {snippet}".lower()
+         url_lower = url.lower()
+         score = 0.0
+
+         # Keyword hits in URL
+         kw_hits = sum(1 for kw in PHISHING_KEYWORDS if kw in url_lower)
+         score += min(kw_hits * 0.08, 0.40)
+
+         # Phrase matches in content
+         phrase_hits = sum(1 for p in PHISHING_TERMS if p in combined)
+         score += min(phrase_hits * 0.12, 0.48)
+
+         # Brand spoofing
+         for brand in BRAND_NAMES:
+             if brand in url_lower:
+                 if f"{brand}.com" not in url_lower:
+                     score += 0.20
+                     break
+
+         # IP as hostname
+         if re.match(r"https?://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", url):
+             score += 0.20
+
+         # Shannon entropy of hostname
+         try:
+             from urllib.parse import urlparse
+             host = urlparse(url if "://" in url else f"http://{url}").hostname or ""
+             if host:
+                 length = len(host)
+                 freq: Dict[str, int] = {}
+                 for c in host:
+                     freq[c] = freq.get(c, 0) + 1
+                 entropy = -sum(
+                     (cnt / length) * math.log2(cnt / length) for cnt in freq.values()
+                 )
+                 if entropy > 3.5:
+                     score += 0.10
+         except Exception:
+             pass
+
+         return round(min(score, 1.0), 4)
+
+     def incremental_update(
+         self,
+         samples: List[Tuple[str, int]],
+         lr: float = 1e-5,
+         epochs: int = 1,
+         label_smoothing: float = 0.1,
+     ) -> Optional[float]:
+         """
+         Incremental update: unfreeze last 2 transformer layers only.
+         Returns accuracy_delta (float) or None if update failed.
+
+         samples: list of (url, label) where label is 0 or 1
+         """
+         if not self._loaded or self._model is None or self._tokenizer is None:
+             logger.warning("BERT not loaded, cannot incrementally update")
+             return None
+
+         if len(samples) < 5:
+             logger.warning(f"Too few samples ({len(samples)}) for BERT update")
+             return None
+
+         try:
+             import torch
+             from torch.utils.data import DataLoader, TensorDataset
+             from torch.optim import AdamW
+
+             device = torch.device("cpu")
+             model = self._model.to(device)
+
+             # Freeze all layers
+             for param in model.parameters():
+                 param.requires_grad = False
+
+             # Unfreeze last 2 transformer layers + classifier
+             if hasattr(model, "bert"):
+                 encoder_layers = model.bert.encoder.layer
+                 for layer in encoder_layers[-2:]:
+                     for param in layer.parameters():
+                         param.requires_grad = True
+             if hasattr(model, "classifier"):
+                 for param in model.classifier.parameters():
+                     param.requires_grad = True
+
+             # Prepare data
+             texts = [self.tokenize_url(url) for url, _ in samples]
+             labels = [label for _, label in samples]
+
+             encodings = self._tokenizer(
+                 texts, truncation=True, padding=True, max_length=512,
+                 return_tensors="pt"
+             )
+             label_tensor = torch.tensor(labels, dtype=torch.long).to(device)
+
+             dataset = TensorDataset(
+                 encodings["input_ids"].to(device),
+                 encodings["attention_mask"].to(device),
+                 label_tensor,
+             )
+             batch_size = min(len(samples), 16)
+             loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+             # Pre-update accuracy
+             model.eval()
+             with torch.no_grad():
+                 pre_correct = 0
+                 for batch in loader:
+                     ids, mask, labs = batch
+                     outputs = model(input_ids=ids, attention_mask=mask)
+                     preds = torch.argmax(outputs.logits, dim=1)
+                     pre_correct += (preds == labs).sum().item()
+             pre_acc = pre_correct / len(samples)
+
+             # Train
+             optimizer = AdamW(
+                 filter(lambda p: p.requires_grad, model.parameters()),
+                 lr=lr,
+             )
+             loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+             model.train()
+             for epoch in range(epochs):
+                 total_loss = 0.0
+                 for batch in loader:
+                     ids, mask, labs = batch
+                     optimizer.zero_grad()
+                     outputs = model(input_ids=ids, attention_mask=mask)
+                     loss = loss_fn(outputs.logits, labs)
+                     loss.backward()
+                     optimizer.step()
+                     total_loss += loss.item()
+                 logger.info(f"BERT incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(loader):.4f}")
+
+             # Post-update accuracy
+             model.eval()
+             with torch.no_grad():
+                 post_correct = 0
+                 for batch in loader:
+                     ids, mask, labs = batch
+                     outputs = model(input_ids=ids, attention_mask=mask)
+                     preds = torch.argmax(outputs.logits, dim=1)
+                     post_correct += (preds == labs).sum().item()
+             post_acc = post_correct / len(samples)
+
+             delta = post_acc - pre_acc
+             self._model = model
+             logger.info(f"BERT incremental update: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})")
+             return round(delta, 4)
+
+         except Exception as e:
+             logger.error(f"BERT incremental update failed: {e}")
+             return None
+
+     def save(self, path: Path) -> None:
+         """Save model and tokenizer to directory."""
+         if self._model and self._tokenizer:
+             path = Path(path)
+             path.mkdir(parents=True, exist_ok=True)
+             self._model.save_pretrained(str(path))
+             self._tokenizer.save_pretrained(str(path))
+             logger.info(f"BERT model saved to {path}")
+
+     def load_local(self, path: Path) -> bool:
+         """Load model from local directory."""
+         path = Path(path)
+         if not path.exists():
+             return False
+         try:
+             from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
+             self._tokenizer = AutoTokenizer.from_pretrained(str(path))
+             self._model = AutoModelForSequenceClassification.from_pretrained(str(path))
+             self._pipeline = pipeline(
+                 "text-classification",
+                 model=self._model,
+                 tokenizer=self._tokenizer,
+                 truncation=True,
+                 max_length=512,
+                 device=-1,
+             )
+             self._loaded = True
+             logger.info(f"BERT model loaded from {path}")
+             return True
+         except Exception as e:
+             logger.error(f"BERT local load failed: {e}")
+             return False
+
+     @property
+     def is_loaded(self) -> bool:
+         return self._loaded
+
+
+ # ── Legacy compatibility ─────────────────────────────────────────────
+ _default_classifier = BERTPhishingClassifier()
+
+
+ def analyze_text(url: str, page_title: str = "", page_snippet: str = "") -> dict:
+     """Legacy wrapper for backward compatibility with main.py."""
+     prob = _default_classifier.predict(url, page_title, page_snippet)
+     return {
+         "bert_phishing_prob": prob,
+         "phrase_hits": 0,
+         "label": "BERT" if _default_classifier.is_loaded else "KEYWORD_NLP",
+         "confidence": prob,
+     }
+
+
+ def shannon_entropy(s: str) -> float:
+     """Utility: measure randomness of a string."""
+     if not s:
+         return 0.0
+     prob = [s.count(c) / len(s) for c in set(s)]
+     return -sum(p * math.log2(p) for p in prob if p > 0)
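
A quick usage sketch of the classifier (illustrative URL; output values will vary): load_model() is called lazily inside predict(), and the keyword fallback kicks in automatically when transformers is missing.

# Usage sketch; mirrors the public API defined above.
from bert_analyzer import BERTPhishingClassifier, analyze_text

clf = BERTPhishingClassifier()
print(clf.tokenize_url("https://paypa1-l0gin.xyz/verify"))
# -> "paypa1 l0gin xyz verify"   (homoglyphs like "paypa1" survive intact)

p = clf.predict(
    "https://paypa1-l0gin.xyz/verify",
    title="Verify your account",
    snippet="We noticed unusual activity. Confirm your identity now.",
)
print(f"P_bert = {p:.4f}")        # phishing probability in [0, 1]

# Legacy entry point used by main.py:
print(analyze_text("https://paypa1-l0gin.xyz/verify"))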
bert_finetune.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # PhishGuard AI - bert_finetune.py
3
+ # Full BERT fine-tuning script on PhishTank + TRANCO data
4
+ #
5
+ # Downloads data, fine-tunes ealvaradob/bert-finetuned-phishing
6
+ # 3 epochs, AdamW + linear warmup scheduler
7
+ # Saves to bert_weights/ with save_pretrained()
8
+ # Prints per-epoch: loss / precision / recall / F1
9
+ # ============================================================
10
+
11
+ from __future__ import annotations
12
+
13
+ import logging
14
+ import sys
15
+ from pathlib import Path
16
+ from typing import List, Tuple
17
+
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format="%(asctime)s | %(levelname)-7s | %(message)s",
21
+ )
22
+ logger = logging.getLogger("phishguard.bert_finetune")
23
+
24
+ BASE_DIR = Path(__file__).parent
25
+ BERT_WEIGHTS_DIR = BASE_DIR / "bert_weights"
26
+
27
+
28
+ def main() -> None:
29
+ """Fine-tune BERT on PhishTank + TRANCO URLs."""
30
+ print("=" * 60)
31
+ print("PhishGuard AI β€” BERT Fine-Tuning")
32
+ print("=" * 60)
33
+
34
+ # ── Check dependencies ───────────────────────────────────────
35
+ try:
36
+ import torch
37
+ from torch.utils.data import DataLoader, Dataset
38
+ from torch.optim import AdamW
39
+ from transformers import (
40
+ AutoTokenizer,
41
+ AutoModelForSequenceClassification,
42
+ get_linear_schedule_with_warmup,
43
+ )
44
+ from sklearn.metrics import precision_recall_fscore_support
45
+ except ImportError as e:
46
+ print(f"❌ Missing dependency: {e}")
47
+ print(" Run: pip install torch transformers scikit-learn")
48
+ sys.exit(1)
49
+
50
+ # ── Download data ────────────────────────────────────────────
51
+ from data_collector import download_phishtank, download_tranco, merge_datasets
52
+
53
+ print("\nπŸ“₯ Downloading datasets...")
54
+ phish_urls = download_phishtank(max_urls=50)
55
+ legit_urls = download_tranco(n=50)
56
+ print(f" Phishing URLs: {len(phish_urls)}")
57
+ print(f" Legitimate URLs: {len(legit_urls)}")
58
+
59
+ train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)
60
+
61
+ # ── URL tokenization ─────────────────────────────────────────
62
+ import re
63
+ _re_url_split = re.compile(r"[-./=?&_~%@:]+")
64
+
65
+ def tokenize_url(url: str) -> str:
66
+ text = url.replace("https://", "").replace("http://", "")
67
+ tokens = _re_url_split.split(text)
68
+ return " ".join(t for t in tokens if t)
69
+
70
+ # ── Dataset class ────────────────────────────────────────────
71
+ class PhishingURLDataset(Dataset):
72
+ def __init__(self, data: List[Tuple[str, int]], tokenizer, max_length: int = 512):
73
+ self.data = data
74
+ self.tokenizer = tokenizer
75
+ self.max_length = max_length
76
+
77
+ def __len__(self) -> int:
78
+ return len(self.data)
79
+
80
+ def __getitem__(self, idx: int):
81
+ url, label = self.data[idx]
82
+ text = f"URL: {tokenize_url(url)}"
83
+ encoding = self.tokenizer(
84
+ text,
85
+ truncation=True,
86
+ padding="max_length",
87
+ max_length=self.max_length,
88
+ return_tensors="pt",
89
+ )
90
+ return {
91
+ "input_ids": encoding["input_ids"].squeeze(0),
92
+ "attention_mask": encoding["attention_mask"].squeeze(0),
93
+ "labels": torch.tensor(label, dtype=torch.long),
94
+ }
95
+
96
+ # ── Load model ───────────────────────────────────────────────
97
+ MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
98
+ FALLBACK = "mrm8488/bert-tiny-finetuned-sms-spam-detection"
99
+
100
+ print("\nπŸ€– Loading BERT model...")
101
+ tokenizer = None
102
+ model = None
103
+ for model_id in [MODEL_NAME, FALLBACK]:
104
+ try:
105
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
106
+ model = AutoModelForSequenceClassification.from_pretrained(
107
+ model_id, num_labels=2
108
+ )
109
+ print(f" βœ… Loaded: {model_id}")
110
+ break
111
+ except Exception as e:
112
+ print(f" ⚠️ {model_id} failed: {e}")
113
+ continue
114
+
115
+ if model is None or tokenizer is None:
116
+ print("❌ Could not load any BERT model. Exiting.")
117
+ sys.exit(1)
118
+
119
+ # ── Prepare data ─────────────────────────────────────────────
120
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
121
+ print(f" Device: {device}")
122
+
123
+ train_dataset = PhishingURLDataset(train_data, tokenizer)
124
+ val_dataset = PhishingURLDataset(val_data, tokenizer)
125
+
126
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
127
+ val_loader = DataLoader(val_dataset, batch_size=32)
128
+
129
+ model = model.to(device)
130
+
131
+ # ── Optimizer + Scheduler ────────────────────────────────────
132
+ EPOCHS = 1
133
+ optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
134
+ total_steps = len(train_loader) * EPOCHS
135
+ scheduler = get_linear_schedule_with_warmup(
136
+ optimizer,
137
+ num_warmup_steps=total_steps // 10,
138
+ num_training_steps=total_steps,
139
+ )
140
+
141
+ # ── Training Loop ────────────────────────────────────────────
142
+ print(f"\nπŸ‹οΈ Training for {EPOCHS} epochs...")
143
+ print(f" Train batches: {len(train_loader)}")
144
+ print(f" Val batches: {len(val_loader)}")
145
+
146
+ best_f1 = 0.0
147
+ for epoch in range(1, EPOCHS + 1):
148
+ # Train
149
+ model.train()
150
+ total_loss = 0.0
151
+ train_preds = []
152
+ train_labels = []
153
+
154
+ for batch_idx, batch in enumerate(train_loader):
155
+ input_ids = batch["input_ids"].to(device)
156
+ attention_mask = batch["attention_mask"].to(device)
157
+ labels = batch["labels"].to(device)
158
+
159
+ optimizer.zero_grad()
160
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
161
+ loss = outputs.loss
162
+ loss.backward()
163
+
164
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
165
+ optimizer.step()
166
+ scheduler.step()
167
+
168
+ total_loss += loss.item()
169
+ preds = torch.argmax(outputs.logits, dim=1)
170
+ train_preds.extend(preds.cpu().tolist())
171
+ train_labels.extend(labels.cpu().tolist())
172
+
173
+ if (batch_idx + 1) % 50 == 0:
174
+ print(f" Epoch {epoch} | Batch {batch_idx+1}/{len(train_loader)} | Loss: {loss.item():.4f}")
175
+
176
+ avg_loss = total_loss / len(train_loader)
177
+
178
+ # Validate
179
+ model.eval()
180
+ val_preds = []
181
+ val_labels = []
182
+ with torch.no_grad():
183
+ for batch in val_loader:
184
+ input_ids = batch["input_ids"].to(device)
185
+ attention_mask = batch["attention_mask"].to(device)
186
+ labels = batch["labels"].to(device)
187
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
188
+ preds = torch.argmax(outputs.logits, dim=1)
189
+ val_preds.extend(preds.cpu().tolist())
190
+ val_labels.extend(labels.cpu().tolist())
191
+
192
+ precision, recall, f1, _ = precision_recall_fscore_support(
193
+ val_labels, val_preds, average="binary", zero_division=0
194
+ )
195
+
196
+ print(f"\n πŸ“Š Epoch {epoch}/{EPOCHS}:")
197
+ print(f" Loss: {avg_loss:.4f}")
198
+ print(f" Precision: {precision:.4f}")
199
+ print(f" Recall: {recall:.4f}")
200
+ print(f" F1 Score: {f1:.4f}")
201
+
202
+ # Save best model
203
+ if f1 > best_f1:
204
+ best_f1 = f1
205
+ BERT_WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
206
+ model.save_pretrained(str(BERT_WEIGHTS_DIR))
207
+ tokenizer.save_pretrained(str(BERT_WEIGHTS_DIR))
208
+ print(f" βœ… New best model saved to {BERT_WEIGHTS_DIR}")
209
+
210
+ print(f"\n🎯 Best F1: {best_f1:.4f}")
211
+ print(f"βœ… Fine-tuning complete. Weights saved to: {BERT_WEIGHTS_DIR}")
212
+ print("=" * 60)
213
+
214
+
215
+ if __name__ == "__main__":
216
+ main()
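
The script saves its best checkpoint with save_pretrained, so the standard from_pretrained API can reload it. A minimal reload sketch, assuming a local bert_weights/ directory stands in for BERT_WEIGHTS_DIR:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert_weights")                    # hypothetical path
model = AutoModelForSequenceClassification.from_pretrained("bert_weights")
model.eval()

# Same input format the training loop used: "URL: " + tokenize_url(url)
enc = tokenizer("URL: paypal verify xyz login", truncation=True,
                max_length=512, return_tensors="pt")
with torch.no_grad():
    probs = torch.softmax(model(**enc).logits, dim=-1)
print(f"P(phishing) = {probs[0, 1].item():.4f}")  # index 1 = phishing, matching the training labels
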
cnn_inference.py ADDED
@@ -0,0 +1,237 @@
1
+ # ============================================================
2
+ # PhishGuard AI - cnn/cnn_inference.py
3
+ # CNN inference wrapper for Tier 4 visual analysis.
4
+ # Supports: predict, hot-reload, incremental_update.
5
+ # ============================================================
6
+
7
+ from __future__ import annotations
8
+
9
+ import io
10
+ import random
11
+ import logging
12
+ from pathlib import Path
13
+ from typing import List, Optional, Tuple
14
+
15
+ import torch
16
+ from PIL import Image
17
+
18
+ logger = logging.getLogger("phishguard.cnn.inference")
19
+
20
+ CNN_DIR = Path(__file__).parent
21
+ BACKEND_DIR = CNN_DIR.parent
22
+ WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt"
23
+ REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt"
24
+
25
+
26
+ class CNNInference:
27
+ """CNN inference wrapper with hot-reload and incremental update."""
28
+
29
+ def __init__(self, weights_path: Optional[Path] = None) -> None:
30
+ self._weights_path = weights_path or WEIGHTS_PATH
31
+ self._model = None
32
+ self._loaded = False
33
+
34
+ def load(self, weights_path: Optional[Path] = None) -> bool:
35
+ """Load CNN model."""
36
+ from cnn_model import load_cnn
37
+
38
+ path = weights_path or self._weights_path
39
+ self._model = load_cnn(str(path) if path.exists() else None)
40
+ self._loaded = self._model is not None
41
+ return self._loaded
42
+
43
+ def predict(self, screenshot_bytes: bytes) -> float:
44
+ """
45
+ Predict phishing probability from screenshot bytes.
46
+ Returns P_cnn ∈ [0,1].
47
+ """
48
+ if not self._loaded:
49
+ self.load()
50
+
51
+ if self._model is None:
52
+ return 0.5
53
+
54
+ from cnn_model import preprocess_screenshot
55
+
56
+ try:
57
+ tensor = preprocess_screenshot(screenshot_bytes)
58
+ return self._model.predict_proba(tensor)
59
+ except Exception as e:
60
+ logger.error(f"CNN predict failed: {e}")
61
+ return 0.5
62
+
63
+ def reload(self, weights_path: Optional[Path] = None) -> bool:
64
+ """Hot-reload model with new weights."""
65
+ from cnn_model import load_cnn
66
+
67
+ path = weights_path or self._weights_path
68
+ new_model = load_cnn(str(path))
69
+ if new_model is not None:
70
+ self._model = new_model
71
+ self._loaded = True
72
+ logger.info(f"CNN hot-reloaded from {path}")
73
+ return True
74
+ return False
75
+
76
+ async def incremental_update(
77
+ self,
78
+ tier4_samples: List[Tuple[str, int]],
79
+ replay_buffer_path: Optional[Path] = None,
80
+ lr: float = 1e-4,
81
+ epochs: int = 3,
82
+ ) -> Optional[float]:
83
+ """
84
+ Incremental update on Tier 4 feedback samples.
85
+ Re-captures screenshots via Playwright, trains on them + replay buffer.
86
+ Returns accuracy_delta or None if no Tier 4 samples.
87
+ """
88
+ if not tier4_samples:
89
+ logger.info("No Tier 4 samples β€” skipping CNN update")
90
+ return None
91
+
92
+ if self._model is None:
93
+ logger.warning("CNN not loaded, cannot update")
94
+ return None
95
+
96
+ try:
97
+ import torch.nn as nn
98
+ from torch.optim import AdamW
99
+ from torch.utils.data import DataLoader, TensorDataset
100
+ import torchvision.transforms as T
101
+
102
+ device = torch.device("cpu")
103
+ model = self._model.to(device)
104
+
105
+ transform = T.Compose([
106
+ T.Resize((224, 224)),
107
+ T.ToTensor(),
108
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
109
+ ])
110
+
111
+ # Try to capture screenshots for the new samples
112
+ tensors = []
113
+ labels = []
114
+
115
+ for url, label in tier4_samples:
116
+ try:
117
+ # Try to capture screenshot
118
+ screenshot_bytes = await self._capture_screenshot(url)
119
+ if screenshot_bytes:
120
+ img = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")
121
+ tensor = transform(img)
122
+ tensors.append(tensor)
123
+ labels.append(float(label))
124
+ except Exception as e:
125
+ logger.warning(f"Screenshot capture failed for {url}: {e}")
126
+ continue
127
+
128
+ # Load replay buffer (20% mix)
129
+ buf_path = replay_buffer_path or REPLAY_BUFFER_PATH
130
+ if buf_path.exists():
131
+ try:
132
+ buf_data = torch.load(buf_path, map_location="cpu", weights_only=False)
133
+ buf_paths = buf_data.get("paths", [])
134
+ buf_labels = buf_data.get("labels", [])
135
+
136
+ replay_count = max(1, len(buf_paths) // 5)
137
+ indices = random.sample(range(len(buf_paths)), min(replay_count, len(buf_paths)))
138
+
139
+ for idx in indices:
140
+ try:
141
+ img = Image.open(buf_paths[idx]).convert("RGB")
142
+ tensor = transform(img)
143
+ tensors.append(tensor)
144
+ labels.append(float(buf_labels[idx]))
145
+ except Exception:
146
+ continue
147
+ except Exception as e:
148
+ logger.warning(f"CNN replay buffer load failed: {e}")
149
+
150
+ if len(tensors) < 5:
151
+ logger.warning(f"Too few CNN samples ({len(tensors)}), skipping update")
152
+ return None
153
+
154
+ # Stack and create dataset
155
+ x_data = torch.stack(tensors)
156
+ y_data = torch.tensor(labels, dtype=torch.float)
157
+ dataset = TensorDataset(x_data, y_data)
158
+ loader = DataLoader(dataset, batch_size=8, shuffle=True)
159
+
160
+ # Pre-update accuracy
161
+ model.eval()
162
+ pre_correct = 0
163
+ with torch.no_grad():
164
+ for bx, by in loader:
165
+ bx, by = bx.to(device), by.to(device)
166
+ out = model(bx).view(-1)  # view(-1), not squeeze(): squeeze() yields a 0-dim tensor when a batch has one sample
167
+ preds = (out >= 0.5).float()
168
+ pre_correct += (preds == by).sum().item()
169
+ pre_acc = pre_correct / len(dataset)
170
+
171
+ # Train (head only β€” backbone stays frozen)
172
+ head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
173
+ optimizer = AdamW(head_params, lr=lr)
174
+ loss_fn = nn.BCELoss()
175
+
176
+ model.train()
177
+ for epoch in range(epochs):
178
+ total_loss = 0.0
179
+ for bx, by in loader:
180
+ bx, by = bx.to(device), by.to(device)
181
+ optimizer.zero_grad()
182
+ out = model(bx).view(-1)  # keep shape (B,) so it matches by for BCELoss
183
+ loss = loss_fn(out, by)
184
+ loss.backward()
185
+ optimizer.step()
186
+ total_loss += loss.item()
187
+ logger.info(f"CNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(loader):.4f}")
188
+
189
+ # Post-update accuracy
190
+ model.eval()
191
+ post_correct = 0
192
+ with torch.no_grad():
193
+ for bx, by in loader:
194
+ bx, by = bx.to(device), by.to(device)
195
+ out = model(bx).view(-1)
196
+ preds = (out >= 0.5).float()
197
+ post_correct += (preds == by).sum().item()
198
+ post_acc = post_correct / len(dataset)
199
+
200
+ delta = post_acc - pre_acc
201
+ self._model = model
202
+
203
+ # Save weights
204
+ torch.save(model.state_dict(), self._weights_path)
205
+ logger.info(f"CNN incremental: {pre_acc:.4f} β†’ {post_acc:.4f} (Ξ”={delta:+.4f})")
206
+
207
+ return round(delta, 4)
208
+
209
+ except Exception as e:
210
+ logger.error(f"CNN incremental update failed: {e}")
211
+ return None
212
+
213
+ async def _capture_screenshot(self, url: str) -> Optional[bytes]:
214
+ """Capture a screenshot of a URL using Playwright."""
215
+ try:
216
+ from playwright.async_api import async_playwright
217
+
218
+ async with async_playwright() as p:
219
+ browser = await p.chromium.launch(headless=True)
220
+ page = await browser.new_page(viewport={"width": 1280, "height": 800})
221
+
222
+ # Block heavy resources
223
+ await page.route("**/*.{png,jpg,jpeg,gif,svg,mp4,webm,ogg,woff,woff2,ttf,eot}",
224
+ lambda route: route.abort())
225
+
226
+ await page.goto(url, wait_until="domcontentloaded", timeout=10000)
227
+ screenshot = await page.screenshot(type="png")
228
+ await browser.close()
229
+ return screenshot
230
+
231
+ except Exception as e:
232
+ logger.warning(f"Screenshot capture failed: {e}")
233
+ return None
234
+
235
+ @property
236
+ def is_loaded(self) -> bool:
237
+ return self._loaded
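
A usage sketch for the wrapper above; the screenshot file is a hypothetical stand-in for bytes captured by Playwright:

from pathlib import Path
from cnn_inference import CNNInference

cnn = CNNInference()
png_bytes = Path("page.png").read_bytes()   # hypothetical screenshot file
p_cnn = cnn.predict(png_bytes)              # falls back to 0.5 if no model can be loaded
print(f"P_cnn = {p_cnn:.3f}")
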
cnn_model.py ADDED
@@ -0,0 +1,121 @@
1
+ # ============================================================
2
+ # PhishGuard AI - cnn/cnn_model.py
3
+ # ResNet50 visual classifier for phishing screenshot detection.
4
+ #
5
+ # Architecture (from spec):
6
+ # Backbone: ResNet50 fully frozen
7
+ # Custom head: Linear(2048β†’512) β†’ ReLU β†’ Dropout(0.5) β†’
8
+ # Linear(512β†’1) β†’ Sigmoid
9
+ # Input: 224Γ—224 screenshot tensor
10
+ # Output: P_cnn ∈ [0,1]
11
+ # ============================================================
12
+
13
+ from __future__ import annotations
14
+
15
+ import io
16
+ import logging
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torchvision.models as models
22
+ import torchvision.transforms as T
23
+ from PIL import Image
24
+
25
+ logger = logging.getLogger("phishguard.cnn.model")
26
+
27
+
28
+ class PhishCNN(nn.Module):
29
+ """
30
+ ResNet50 with frozen backbone and custom 2-layer binary classification head.
31
+ Output: P_cnn ∈ [0,1] via sigmoid.
32
+ """
33
+
34
+ def __init__(self, pretrained: bool = True) -> None:
35
+ super().__init__()
36
+
37
+ # Load pretrained ResNet50 backbone
38
+ if pretrained:
39
+ self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
40
+ else:
41
+ self.backbone = models.resnet50(weights=None)
42
+
43
+ # Freeze entire backbone
44
+ for param in self.backbone.parameters():
45
+ param.requires_grad = False
46
+
47
+ # Replace fc with custom head: 2048 β†’ 512 β†’ 1 β†’ sigmoid
48
+ in_features = self.backbone.fc.in_features # 2048
49
+ self.backbone.fc = nn.Sequential(
50
+ nn.Linear(in_features, 512),
51
+ nn.ReLU(),
52
+ nn.Dropout(0.5),
53
+ nn.Linear(512, 1),
54
+ )
55
+
56
+ # Ensure custom head is trainable
57
+ for param in self.backbone.fc.parameters():
58
+ param.requires_grad = True
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Forward pass.
63
+ Input: (batch, 3, 224, 224)
64
+ Output: (batch, 1) probabilities in [0, 1]
65
+ """
66
+ logits = self.backbone(x)
67
+ return torch.sigmoid(logits)
68
+
69
+ def predict_proba(self, x: torch.Tensor) -> float:
70
+ """Return P_cnn ∈ [0,1] β€” probability of phishing."""
71
+ self.eval()
72
+ with torch.no_grad():
73
+ output = self.forward(x)
74
+ return output.squeeze().item()
75
+
76
+
77
+ # ── Preprocessing pipeline (matches ImageNet normalization) ──────────
78
+ TRANSFORM = T.Compose([
79
+ T.Resize((224, 224)),
80
+ T.ToTensor(),
81
+ T.Normalize(
82
+ mean=[0.485, 0.456, 0.406], # ImageNet mean
83
+ std=[0.229, 0.224, 0.225], # ImageNet std
84
+ ),
85
+ ])
86
+
87
+ # Training augmentation transforms
88
+ TRAIN_TRANSFORM = T.Compose([
89
+ T.Resize((224, 224)),
90
+ T.RandomHorizontalFlip(),
91
+ T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
92
+ T.RandomRotation(5),
93
+ T.ToTensor(),
94
+ T.Normalize(
95
+ mean=[0.485, 0.456, 0.406],
96
+ std=[0.229, 0.224, 0.225],
97
+ ),
98
+ ])
99
+
100
+
101
+ def preprocess_screenshot(screenshot_bytes: bytes) -> torch.Tensor:
102
+ """Convert raw screenshot bytes β†’ model-ready tensor [1, 3, 224, 224]."""
103
+ img = Image.open(io.BytesIO(screenshot_bytes)).convert("RGB")
104
+ return TRANSFORM(img).unsqueeze(0)
105
+
106
+
107
+ def load_cnn(weights_path: Optional[str] = None) -> PhishCNN:
108
+ """Load CNN model with optional trained weights."""
109
+ model = PhishCNN(pretrained=True)
110
+
111
+ if weights_path:
112
+ try:
113
+ state = torch.load(weights_path, map_location="cpu", weights_only=True)
114
+ model.load_state_dict(state)
115
+ logger.info(f"CNN weights loaded from {weights_path}")
116
+ except Exception as e:
117
+ logger.warning(f"Could not load CNN weights: {e}")
118
+ logger.info("Using ImageNet features only (baseline)")
119
+
120
+ model.eval()
121
+ return model
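
A minimal end-to-end sketch of this module; if the weights file is absent, load_cnn keeps the ImageNet backbone with an untrained head:

from pathlib import Path
from cnn_model import load_cnn, preprocess_screenshot

model = load_cnn("cnn_weights.pt")            # logs a warning and continues if the file is missing
png_bytes = Path("page.png").read_bytes()     # hypothetical screenshot file
tensor = preprocess_screenshot(png_bytes)     # shape [1, 3, 224, 224]
print(f"P_cnn = {model.predict_proba(tensor):.3f}")
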
config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "_name_or_path": "ealvaradob/bert-finetuned-phishing",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
+ "id2label": {
13
+ "0": "benign",
14
+ "1": "phishing"
15
+ },
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 4096,
18
+ "label2id": {
19
+ "benign": 0,
20
+ "phishing": 1
21
+ },
22
+ "layer_norm_eps": 1e-12,
23
+ "max_position_embeddings": 512,
24
+ "model_type": "bert",
25
+ "num_attention_heads": 16,
26
+ "num_hidden_layers": 24,
27
+ "pad_token_id": 0,
28
+ "position_embedding_type": "absolute",
29
+ "problem_type": "single_label_classification",
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.40.0",
32
+ "type_vocab_size": 2,
33
+ "use_cache": true,
34
+ "vocab_size": 30522
35
+ }
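
This config pins the label mapping the rest of the pipeline relies on (0 = benign, 1 = phishing). A quick sketch of reading it with transformers, assuming the file sits in the model directory being loaded:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("bert_weights")   # hypothetical directory containing this config.json
print(cfg.id2label)                                # {0: 'benign', 1: 'phishing'}
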
content.js ADDED
@@ -0,0 +1,180 @@
1
+ // ============================================================
2
+ // PhishGuard AI - content.js
3
+ // Content script: runs inside every page.
4
+ // Detects phishing signals and injects feedback banner.
5
+ // ============================================================
6
+
7
+ (function() {
8
+ "use strict";
9
+
10
+ // ── Page Signal Detection ──────────────────────────────────────
11
+ function detectPageSignals() {
12
+ const signals = [];
13
+ const title = document.title || "";
14
+ const bodyText = (document.body?.innerText || "").substring(0, 2000).toLowerCase();
15
+ const url = window.location.href.toLowerCase();
16
+
17
+ // 1. Password form posting to external domain
18
+ const forms = document.querySelectorAll("form");
19
+ forms.forEach(form => {
20
+ const hasPassword = form.querySelector('input[type="password"]');
21
+ const action = (form.getAttribute("action") || "").toLowerCase();
22
+ if (hasPassword && action.startsWith("http") && !action.includes(window.location.hostname)) {
23
+ signals.push("password_form_external_action");
24
+ }
25
+ });
26
+
27
+ // 2. Brand name in title mismatching hostname
28
+ const brands = ["paypal","google","apple","microsoft","amazon","netflix",
29
+ "facebook","instagram","chase","wellsfargo","bankofamerica"];
30
+ const hostname = window.location.hostname.toLowerCase();
31
+ for (const brand of brands) {
32
+ if (title.toLowerCase().includes(brand) && !hostname.includes(brand)) {
33
+ signals.push(`brand_mismatch:${brand}`);
34
+ }
35
+ }
36
+
37
+ // 3. Urgency language
38
+ const urgencyPhrases = [
39
+ "your account has been", "verify immediately", "suspended",
40
+ "unusual activity", "click here now", "act now",
41
+ "confirm your identity", "limited time", "expires soon"
42
+ ];
43
+ for (const phrase of urgencyPhrases) {
44
+ if (bodyText.includes(phrase)) {
45
+ signals.push("urgency_language");
46
+ break;
47
+ }
48
+ }
49
+
50
+ // 4. Hidden iframes
51
+ const iframes = document.querySelectorAll("iframe");
52
+ iframes.forEach(iframe => {
53
+ const style = window.getComputedStyle(iframe);
54
+ const w = parseInt(style.width) || iframe.width;
55
+ const h = parseInt(style.height) || iframe.height;
56
+ if (style.display === "none" || style.visibility === "hidden" ||
57
+ (w <= 1 && h <= 1)) {
58
+ signals.push("hidden_iframe");
59
+ }
60
+ });
61
+
62
+ return {
63
+ title,
64
+ snippet: bodyText.substring(0, 500),
65
+ signals,
66
+ };
67
+ }
68
+
69
+ // Send signals to background.js
70
+ try {
71
+ const pageData = detectPageSignals();
72
+ chrome.runtime.sendMessage({
73
+ type: "page_signals",
74
+ url: window.location.href,
75
+ ...pageData,
76
+ });
77
+ } catch (e) {
78
+ // Extension context may be invalidated
79
+ }
80
+
81
+ // ── Feedback Banner Injection ──────────────────────────────────
82
+ // Listen for messages from background.js to inject banner
83
+ chrome.runtime.onMessage.addListener((msg, sender, sendResponse) => {
84
+ if (msg.type === "inject_feedback_banner") {
85
+ injectFeedbackBanner(msg.verdict, msg.confidence, msg.urlHash, msg.tier);
86
+ sendResponse({ success: true });
87
+ }
88
+ });
89
+
90
+ function injectFeedbackBanner(verdict, confidence, urlHash, tier) {
91
+ // Don't inject if already present
92
+ if (document.getElementById("phishguard-feedback-banner")) return;
93
+
94
+ const isPhishing = verdict === "phishing";
95
+ const confPct = Math.round(confidence * 100);
96
+ const tierText = `Tier ${tier}`;
97
+
98
+ const banner = document.createElement("div");
99
+ banner.id = "phishguard-feedback-banner";
100
+ banner.style.cssText = `
101
+ position: fixed; top: 0; left: 0; right: 0; z-index: 2147483647;
102
+ background: ${isPhishing ? "linear-gradient(135deg, #1a0000, #3a0000)" : "linear-gradient(135deg, #001a00, #003a00)"};
103
+ color: white; padding: 10px 20px;
104
+ display: flex; align-items: center; gap: 12px;
105
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
106
+ font-size: 14px; box-shadow: 0 4px 20px rgba(0,0,0,0.5);
107
+ border-bottom: 2px solid ${isPhishing ? "#ef4444" : "#22c55e"};
108
+ `;
109
+
110
+ const icon = isPhishing ? "πŸ›‘οΈ" : "βœ…";
111
+ const statusText = isPhishing ? "PhishGuard flagged this page" : "PhishGuard: Page looks safe";
112
+
113
+ banner.innerHTML = `
114
+ <span style="font-size: 18px">${icon}</span>
115
+ <span style="flex: 1">${statusText} Β· ${confPct}% Β· ${tierText}</span>
116
+ <button id="pg-correct" style="
117
+ background: rgba(34,197,94,0.2); border: 1px solid #22c55e; color: #22c55e;
118
+ padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 13px;
119
+ font-weight: 600; transition: all 0.2s;
120
+ ">πŸ‘ Correct</button>
121
+ <button id="pg-wrong" style="
122
+ background: rgba(239,68,68,0.2); border: 1px solid #ef4444; color: #ef4444;
123
+ padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 13px;
124
+ font-weight: 600; transition: all 0.2s;
125
+ ">πŸ‘Ž Wrong</button>
126
+ ${isPhishing ? `<button id="pg-proceed" style="
127
+ background: transparent; border: 1px solid rgba(255,255,255,0.2); color: #999;
128
+ padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 12px;
129
+ transition: all 0.2s;
130
+ ">Proceed Anyway</button>` : ""}
131
+ <button id="pg-close" style="
132
+ background: none; border: none; color: #666; cursor: pointer;
133
+ font-size: 18px; padding: 0 4px;
134
+ ">Γ—</button>
135
+ `;
136
+
137
+ document.body.prepend(banner);
138
+ document.body.style.marginTop = (banner.offsetHeight) + "px";
139
+
140
+ // Button handlers
141
+ document.getElementById("pg-correct")?.addEventListener("click", () => {
142
+ submitBannerFeedback(urlHash, "correct", banner);
143
+ });
144
+
145
+ document.getElementById("pg-wrong")?.addEventListener("click", () => {
146
+ submitBannerFeedback(urlHash, "incorrect", banner);
147
+ });
148
+
149
+ document.getElementById("pg-proceed")?.addEventListener("click", () => {
150
+ chrome.runtime.sendMessage({ type: "whitelist_url", url: window.location.href });
151
+ removeBanner(banner);
152
+ });
153
+
154
+ document.getElementById("pg-close")?.addEventListener("click", () => {
155
+ removeBanner(banner);
156
+ });
157
+ }
158
+
159
+ function submitBannerFeedback(urlHash, feedback, banner) {
160
+ chrome.runtime.sendMessage({
161
+ type: "submit_feedback",
162
+ url_hash: urlHash,
163
+ feedback: feedback,
164
+ }, (response) => {
165
+ if (response?.success) {
166
+ banner.innerHTML = `
167
+ <span style="font-size: 18px">βœ…</span>
168
+ <span style="flex: 1; color: #22c55e">Thanks! Your feedback helps improve PhishGuard</span>
169
+ `;
170
+ setTimeout(() => removeBanner(banner), 3000);
171
+ }
172
+ });
173
+ }
174
+
175
+ function removeBanner(banner) {
176
+ document.body.style.marginTop = "";
177
+ banner.remove();
178
+ }
179
+
180
+ })();
data_collector.py ADDED
@@ -0,0 +1,364 @@
1
+ # ============================================================
2
+ # PhishGuard AI - data_collector.py
3
+ # Downloads all training data from public HTTP endpoints.
4
+ # No API keys required.
5
+ #
6
+ # Datasets:
7
+ # 1. PhishTank (bz2 JSON β†’ phishing URLs)
8
+ # 2. TRANCO Top-10K (zip CSV β†’ legitimate domains)
9
+ # 3. Kaggle GitHub mirror (CSV β†’ pre-extracted features)
10
+ # ============================================================
11
+
12
+ from __future__ import annotations
13
+
14
+ import bz2
15
+ import csv
16
+ import io
17
+ import json
18
+ import zipfile
19
+ import hashlib
20
+ import logging
21
+ from pathlib import Path
22
+ from typing import List, Tuple, Optional
23
+
24
+ import requests
25
+ import pandas as pd
26
+ from sklearn.model_selection import train_test_split
27
+
28
+ logger = logging.getLogger("phishguard.data_collector")
29
+
30
+ # ── Data directory ────────────────────────────────────────────────────
31
+ DATA_DIR = Path(__file__).parent / "data"
32
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
33
+
34
+ # ── Public URLs (no API keys) ────────────────────────────────────────
35
+ PHISHTANK_URL = "http://data.phishtank.com/data/online-valid.json.bz2"
36
+ TRANCO_URL = "https://tranco-list.eu/top-1m.csv.zip"
37
+ KAGGLE_PRIMARY = "https://raw.githubusercontent.com/GregaVrbancic/Phishing-Dataset/master/dataset_full.csv"
38
+ KAGGLE_BACKUP = "https://raw.githubusercontent.com/datasets/phishing-websites/master/data.csv"
39
+
40
+ HEADERS = {
41
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
42
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
43
+ "Chrome/120.0.0.0 Safari/537.36"
44
+ }
45
+
46
+
47
+ def download_phishtank(max_urls: int = 30000) -> List[str]:
48
+ """
49
+ Download phishing URLs from PhishTank public feed.
50
+ Fetches bz2 β†’ decompresses β†’ parses JSON β†’ filters verified+online.
51
+
52
+ Returns list of verified phishing URLs (up to max_urls).
53
+ """
54
+ logger.info("Downloading PhishTank data...")
55
+ phish_cache = DATA_DIR / "phishing_urls.txt"
56
+
57
+ # Use cache if recent
58
+ if phish_cache.exists() and phish_cache.stat().st_size > 1000:
59
+ urls = phish_cache.read_text().strip().splitlines()
60
+ if len(urls) >= 100:
61
+ logger.info(f"Using cached PhishTank data: {len(urls)} URLs")
62
+ return urls[:max_urls]
63
+
64
+ try:
65
+ resp = requests.get(PHISHTANK_URL, headers=HEADERS, timeout=120, stream=True)
66
+ resp.raise_for_status()
67
+
68
+ # Decompress bz2
69
+ raw_data = bz2.decompress(resp.content)
70
+ records = json.loads(raw_data)
71
+
72
+ # Filter: verified=True AND online (verification_time present)
73
+ urls: List[str] = []
74
+ for record in records:
75
+ if not isinstance(record, dict):
76
+ continue
77
+ url = record.get("url", "").strip()
78
+ verified = record.get("verified", "no")
79
+ online = record.get("online", "no")
80
+
81
+ is_verified = verified in (True, "yes", "true", "True", "1", 1)
82
+ is_online = online in (True, "yes", "true", "True", "1", 1)
83
+
84
+ if url and is_verified and is_online:
85
+ urls.append(url)
86
+ if len(urls) >= max_urls:
87
+ break
88
+
89
+ logger.info(f"PhishTank: {len(urls)} verified+online URLs extracted")
90
+
91
+ # Cache to disk
92
+ phish_cache.write_text("\n".join(urls))
93
+ return urls
94
+
95
+ except Exception as e:
96
+ logger.warning(f"PhishTank download failed: {e}")
97
+ # Fallback: try to use cached data
98
+ if phish_cache.exists():
99
+ urls = phish_cache.read_text().strip().splitlines()
100
+ logger.info(f"Using fallback cached data: {len(urls)} URLs")
101
+ return urls[:max_urls]
102
+
103
+ # Generate synthetic phishing-like URLs for training
104
+ logger.warning("Generating synthetic phishing URLs as fallback")
105
+ return _generate_synthetic_phishing(500)
106
+
107
+
108
+ def _generate_synthetic_phishing(count: int) -> List[str]:
109
+ """Generate synthetic phishing URLs for training when real data unavailable."""
110
+ import random
111
+ brands = ["paypal", "google", "apple", "microsoft", "amazon", "netflix",
112
+ "facebook", "chase", "wellsfargo", "bankofamerica"]
113
+ tlds = [".xyz", ".tk", ".ml", ".ga", ".cf", ".gq", ".pw", ".top", ".click"]
114
+ keywords = ["login", "verify", "secure", "update", "account", "signin",
115
+ "reset", "confirm", "suspend", "banking", "alert", "password"]
116
+ urls: List[str] = []
117
+ for _ in range(count):
118
+ brand = random.choice(brands)
119
+ tld = random.choice(tlds)
120
+ kw = random.choice(keywords)
121
+ sep = random.choice(["-", ".", ""])
122
+ prefix = random.choice(["http://", "https://"])
123
+ sub = random.choice(["", "www.", "secure.", "login.", "m."])
124
+ urls.append(f"{prefix}{sub}{brand}{sep}{kw}{tld}/{kw}/index.html")
125
+ return urls
126
+
127
+
128
+ def download_tranco(n: int = 10000) -> List[str]:
129
+ """
130
+ Download TRANCO Top-1M list, return top-N domains as https:// URLs.
131
+
132
+ Fetches zip β†’ extracts CSV β†’ takes column 2 (domain) β†’ top N rows.
133
+ """
134
+ logger.info(f"Downloading TRANCO top-{n} domains...")
135
+ legit_cache = DATA_DIR / "legitimate_urls.txt"
136
+
137
+ # Use cache if present
138
+ if legit_cache.exists() and legit_cache.stat().st_size > 1000:
139
+ urls = legit_cache.read_text().strip().splitlines()
140
+ if len(urls) >= min(n, 100):
141
+ logger.info(f"Using cached TRANCO data: {len(urls)} domains")
142
+ return urls[:n]
143
+
144
+ try:
145
+ resp = requests.get(TRANCO_URL, headers=HEADERS, timeout=60)
146
+ resp.raise_for_status()
147
+
148
+ # Extract CSV from zip
149
+ with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
150
+ csv_name = zf.namelist()[0]
151
+ csv_data = zf.read(csv_name).decode("utf-8")
152
+
153
+ # Parse: format is "rank,domain" per line
154
+ urls: List[str] = []
155
+ for line in csv_data.strip().splitlines():
156
+ parts = line.split(",")
157
+ if len(parts) >= 2:
158
+ domain = parts[1].strip()
159
+ if domain:
160
+ urls.append(f"https://{domain}")
161
+ if len(urls) >= n:
162
+ break
163
+
164
+ logger.info(f"TRANCO: {len(urls)} legitimate domains extracted")
165
+
166
+ # Cache to disk
167
+ legit_cache.write_text("\n".join(urls))
168
+ return urls
169
+
170
+ except Exception as e:
171
+ logger.warning(f"TRANCO download failed: {e}")
172
+ # Fallback: use cached data or generate synthetic
173
+ if legit_cache.exists():
174
+ urls = legit_cache.read_text().strip().splitlines()
175
+ return urls[:n]
176
+
177
+ logger.warning("Generating synthetic legitimate URLs as fallback")
178
+ return _generate_synthetic_legitimate(n)
179
+
180
+
181
+ def _generate_synthetic_legitimate(count: int) -> List[str]:
182
+ """Generate legitimate-looking URLs as fallback."""
183
+ top_domains = [
184
+ "google.com", "youtube.com", "facebook.com", "amazon.com",
185
+ "wikipedia.org", "twitter.com", "instagram.com", "linkedin.com",
186
+ "microsoft.com", "apple.com", "github.com", "stackoverflow.com",
187
+ "reddit.com", "netflix.com", "paypal.com", "yahoo.com", "bing.com",
188
+ "adobe.com", "dropbox.com", "zoom.us", "slack.com", "spotify.com",
189
+ "twitch.tv", "ebay.com", "walmart.com", "target.com", "cnn.com",
190
+ "bbc.com", "nytimes.com", "medium.com",
191
+ ]
192
+ urls = [f"https://{d}" for d in top_domains]
193
+ # Pad with numbered subpages
194
+ while len(urls) < count:
195
+ d = top_domains[len(urls) % len(top_domains)]
196
+ urls.append(f"https://{d}/page/{len(urls)}")
197
+ return urls[:count]
198
+
199
+
200
+ def download_kaggle_mirror() -> pd.DataFrame:
201
+ """
202
+ Download pre-extracted URL features from Kaggle GitHub mirror.
203
+ Falls back to backup URL if primary fails.
204
+
205
+ Returns DataFrame with features and CLASS_LABEL column.
206
+ """
207
+ logger.info("Downloading Kaggle URL features dataset...")
208
+ kaggle_cache = DATA_DIR / "kaggle_features.csv"
209
+
210
+ if kaggle_cache.exists() and kaggle_cache.stat().st_size > 1000:
211
+ logger.info("Using cached Kaggle features")
212
+ return pd.read_csv(kaggle_cache)
213
+
214
+ for url in [KAGGLE_PRIMARY, KAGGLE_BACKUP]:
215
+ try:
216
+ resp = requests.get(url, headers=HEADERS, timeout=60)
217
+ resp.raise_for_status()
218
+ df = pd.read_csv(io.StringIO(resp.text))
219
+
220
+ # Standardize label column name
221
+ label_candidates = ["CLASS_LABEL", "class_label", "Result", "result", "label"]
222
+ for col in label_candidates:
223
+ if col in df.columns:
224
+ df = df.rename(columns={col: "CLASS_LABEL"})
225
+ break
226
+
227
+ if "CLASS_LABEL" not in df.columns:
228
+ # Try last column
229
+ df = df.rename(columns={df.columns[-1]: "CLASS_LABEL"})
230
+
231
+ # Normalize labels to 0/1
232
+ if df["CLASS_LABEL"].dtype == object:
233
+ df["CLASS_LABEL"] = df["CLASS_LABEL"].map(
234
+ {"legitimate": 0, "phishing": 1, "safe": 0}
235
+ ).fillna(0).astype(int)
236
+ else:
237
+ # Handle -1 as legitimate (common in some datasets)
238
+ df["CLASS_LABEL"] = df["CLASS_LABEL"].apply(
239
+ lambda x: 0 if x <= 0 else 1
240
+ )
241
+
242
+ # Cache
243
+ df.to_csv(kaggle_cache, index=False)
244
+ logger.info(f"Kaggle features: {len(df)} rows, {len(df.columns)} columns")
245
+ return df
246
+
247
+ except Exception as e:
248
+ logger.warning(f"Kaggle mirror {url} failed: {e}")
249
+ continue
250
+
251
+ logger.error("All Kaggle mirrors failed")
252
+ return pd.DataFrame()
253
+
254
+
255
+ def merge_datasets(
256
+ phish_urls: List[str],
257
+ legit_urls: List[str],
258
+ test_size: float = 0.15,
259
+ val_size: float = 0.15,
260
+ ) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]], List[Tuple[str, int]]]:
261
+ """
262
+ Merge phishing + legitimate URLs, return stratified 70/15/15 split.
263
+
264
+ Returns (train, val, test) where each is List[(url, label)].
265
+ Label: 1 = phishing, 0 = legitimate.
266
+ """
267
+ # Deduplicate
268
+ phish_set = set(phish_urls)
269
+ legit_set = set(legit_urls) - phish_set # Ensure no URL in both sets
270
+
271
+ all_data = [(url, 1) for url in phish_set] + [(url, 0) for url in legit_set]
272
+ urls = [d[0] for d in all_data]
273
+ labels = [d[1] for d in all_data]
274
+
275
+ # First split: train+val vs test
276
+ train_val_urls, test_urls, train_val_labels, test_labels = train_test_split(
277
+ urls, labels,
278
+ test_size=test_size,
279
+ stratify=labels,
280
+ random_state=42,
281
+ )
282
+
283
+ # Second split: train vs val
284
+ relative_val = val_size / (1 - test_size)
285
+ train_urls, val_urls, train_labels, val_labels = train_test_split(
286
+ train_val_urls, train_val_labels,
287
+ test_size=relative_val,
288
+ stratify=train_val_labels,
289
+ random_state=42,
290
+ )
291
+
292
+ train = list(zip(train_urls, train_labels))
293
+ val = list(zip(val_urls, val_labels))
294
+ test = list(zip(test_urls, test_labels))
295
+
296
+ logger.info(f"Dataset split: train={len(train)}, val={len(val)}, test={len(test)}")
297
+ return train, val, test
298
+
299
+
300
+ def save_url_lists(
301
+ phish_urls: List[str],
302
+ legit_urls: List[str],
303
+ phish_path: Optional[Path] = None,
304
+ legit_path: Optional[Path] = None,
305
+ ) -> None:
306
+ """Save URL lists to text files."""
307
+ phish_path = phish_path or DATA_DIR / "phishing_urls.txt"
308
+ legit_path = legit_path or DATA_DIR / "legitimate_urls.txt"
309
+
310
+ phish_path.write_text("\n".join(phish_urls))
311
+ legit_path.write_text("\n".join(legit_urls))
312
+ logger.info(f"Saved {len(phish_urls)} phishing URLs to {phish_path}")
313
+ logger.info(f"Saved {len(legit_urls)} legitimate URLs to {legit_path}")
314
+
315
+
316
+ def url_hash(url: str) -> str:
317
+ """SHA256 hash of a URL (for dedup and privacy)."""
318
+ return hashlib.sha256(url.encode("utf-8")).hexdigest()
319
+
320
+
321
+ # ── Entry point ──────────────────────────────────────────────────────
322
+ def main() -> None:
323
+ logging.basicConfig(
324
+ level=logging.INFO,
325
+ format="%(asctime)s | %(levelname)-7s | %(message)s",
326
+ )
327
+
328
+ print("=" * 60)
329
+ print("PhishGuard AI β€” Data Collection")
330
+ print("=" * 60)
331
+
332
+ # 1. PhishTank
333
+ phish_urls = download_phishtank()
334
+ print(f"\nβœ… PhishTank: {len(phish_urls)} phishing URLs")
335
+
336
+ # 2. TRANCO
337
+ legit_urls = download_tranco(n=10000)
338
+ print(f"βœ… TRANCO: {len(legit_urls)} legitimate URLs")
339
+
340
+ # 3. Kaggle features
341
+ kaggle_df = download_kaggle_mirror()
342
+ if not kaggle_df.empty:
343
+ phish_count = (kaggle_df["CLASS_LABEL"] == 1).sum()
344
+ legit_count = (kaggle_df["CLASS_LABEL"] == 0).sum()
345
+ print(f"βœ… Kaggle: {len(kaggle_df)} rows ({phish_count} phishing, {legit_count} legit)")
346
+ else:
347
+ print("⚠️ Kaggle: download failed (will use PhishTank + TRANCO only)")
348
+
349
+ # 4. Save URL lists
350
+ save_url_lists(phish_urls, legit_urls)
351
+
352
+ # 5. Merge and split
353
+ train, val, test = merge_datasets(phish_urls, legit_urls)
354
+ print(f"\nπŸ“Š Dataset splits:")
355
+ print(f" Train: {len(train)} ({sum(1 for _,l in train if l==1)} phish / {sum(1 for _,l in train if l==0)} legit)")
356
+ print(f" Val: {len(val)} ({sum(1 for _,l in val if l==1)} phish / {sum(1 for _,l in val if l==0)} legit)")
357
+ print(f" Test: {len(test)} ({sum(1 for _,l in test if l==1)} phish / {sum(1 for _,l in test if l==0)} legit)")
358
+
359
+ print(f"\nβœ… All data saved to {DATA_DIR}")
360
+ print("=" * 60)
361
+
362
+
363
+ if __name__ == "__main__":
364
+ main()
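
The two-stage split in merge_datasets works because the second ratio is rescaled: val_size / (1 - test_size) = 0.15 / 0.85 β‰ˆ 17.6% of the remaining 85%, which is 15% of the original data, giving 70/15/15 overall. A small sketch with synthetic URL lists:

from data_collector import merge_datasets

phish = [f"http://phish-{i}.xyz/login" for i in range(100)]
legit = [f"https://site-{i}.com" for i in range(100)]
train, val, test = merge_datasets(phish, legit)
print(len(train), len(val), len(test))   # 140 30 30 -> 70/15/15
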
domain_graph_builder.py ADDED
@@ -0,0 +1,303 @@
1
+ # ============================================================
2
+ # PhishGuard AI - gnn/domain_graph_builder.py
3
+ # Builds graph representations for GNN inference + training.
4
+ #
5
+ # Node features (12-dim per URL):
6
+ # [url_len_norm, domain_len_norm, subdomain_count_norm,
7
+ # shannon_entropy_norm, digit_ratio, hyphen_count_norm,
8
+ # phishing_keyword_hits_norm, suspicious_tld_binary,
9
+ # ip_as_hostname_binary, has_https_binary,
10
+ # path_depth_norm, query_string_len_norm]
11
+ #
12
+ # Edges: shared suspicious TLD + shared IP (async DNS)
13
+ # ============================================================
14
+
15
+ from __future__ import annotations
16
+
17
+ import re
18
+ import math
19
+ import asyncio
20
+ import logging
21
+ import socket
22
+ from typing import Dict, List, Optional, Tuple
23
+ from urllib.parse import urlparse
24
+
25
+ import numpy as np
26
+
27
+ logger = logging.getLogger("phishguard.gnn.graph_builder")
28
+
29
+ # ── Constants ────────────────────────────────────────────────────────
30
+ SUSPICIOUS_TLDS = frozenset({
31
+ ".xyz", ".tk", ".ml", ".ga", ".cf",
32
+ ".gq", ".pw", ".top", ".click",
33
+ })
34
+
35
+ PHISHING_KEYWORDS = frozenset({
36
+ "login", "verify", "secure", "update", "account",
37
+ "banking", "signin", "reset", "confirm", "suspend",
38
+ "webscr", "cmd", "payment", "alert",
39
+ })
40
+
41
+ _re_ip = re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")
42
+
43
+
44
+ class DomainGraphBuilder:
45
+ """
46
+ Builds PyTorch Geometric Data objects from URL lists.
47
+ Each URL becomes a node with 12-dim feature vector.
48
+ Edges are created from shared IP addresses and shared TLDs.
49
+ """
50
+
51
+ def __init__(self) -> None:
52
+ self._re_ip = _re_ip
53
+
54
+ def extract_node_features(self, url: str) -> np.ndarray:
55
+ """
56
+ Extract 12-dim feature vector from a URL.
57
+
58
+ Returns np.ndarray of shape (12,) with values in [0, 1].
59
+ """
60
+ try:
61
+ parsed = urlparse(url if "://" in url else f"http://{url}")
62
+ except Exception:
63
+ return np.zeros(12, dtype=np.float32)
64
+
65
+ hostname: str = (parsed.hostname or "").lower()
66
+ path: str = parsed.path or ""
67
+ query: str = parsed.query or ""
68
+ scheme: str = parsed.scheme or ""
69
+
70
+ # 1. url_len_norm (normalized by 500)
71
+ url_len_norm = min(len(url) / 500.0, 1.0)
72
+
73
+ # 2. domain_len_norm (normalized by 100)
74
+ domain_len_norm = min(len(hostname) / 100.0, 1.0)
75
+
76
+ # 3. subdomain_count_norm
77
+ parts = hostname.split(".")
78
+ subdomain_count = max(0, len(parts) - 2)
79
+ subdomain_count_norm = min(subdomain_count / 10.0, 1.0)
80
+
81
+ # 4. shannon_entropy_norm (normalized by 5.0)
82
+ entropy = self._shannon_entropy(hostname)
83
+ shannon_entropy_norm = min(entropy / 5.0, 1.0)
84
+
85
+ # 5. digit_ratio
86
+ digit_ratio = 0.0
87
+ if hostname:
88
+ digits = sum(1 for c in hostname if c.isdigit())
89
+ digit_ratio = digits / len(hostname)
90
+
91
+ # 6. hyphen_count_norm
92
+ hyphen_count = hostname.count("-")
93
+ hyphen_count_norm = min(hyphen_count / 10.0, 1.0)
94
+
95
+ # 7. phishing_keyword_hits_norm
96
+ url_lower = url.lower()
97
+ keyword_hits = sum(1 for kw in PHISHING_KEYWORDS if kw in url_lower)
98
+ phishing_keyword_hits_norm = min(keyword_hits / 5.0, 1.0)
99
+
100
+ # 8. suspicious_tld_binary
101
+ suspicious_tld_binary = 0.0
102
+ for tld in SUSPICIOUS_TLDS:
103
+ if hostname.endswith(tld):
104
+ suspicious_tld_binary = 1.0
105
+ break
106
+
107
+ # 9. ip_as_hostname_binary
108
+ ip_as_hostname_binary = 1.0 if self._re_ip.match(hostname) else 0.0
109
+
110
+ # 10. has_https_binary
111
+ has_https_binary = 1.0 if scheme == "https" else 0.0
112
+
113
+ # 11. path_depth_norm
114
+ path_segments = [s for s in path.split("/") if s]
115
+ path_depth_norm = min(len(path_segments) / 10.0, 1.0)
116
+
117
+ # 12. query_string_len_norm
118
+ query_string_len_norm = min(len(query) / 500.0, 1.0)
119
+
120
+ features = np.array([
121
+ url_len_norm,
122
+ domain_len_norm,
123
+ subdomain_count_norm,
124
+ shannon_entropy_norm,
125
+ digit_ratio,
126
+ hyphen_count_norm,
127
+ phishing_keyword_hits_norm,
128
+ suspicious_tld_binary,
129
+ ip_as_hostname_binary,
130
+ has_https_binary,
131
+ path_depth_norm,
132
+ query_string_len_norm,
133
+ ], dtype=np.float32)
134
+
135
+ return features
136
+
137
+ def _shannon_entropy(self, s: str) -> float:
138
+ """Compute Shannon entropy of a string."""
139
+ if not s:
140
+ return 0.0
141
+ length = len(s)
142
+ freq: Dict[str, int] = {}
143
+ for c in s:
144
+ freq[c] = freq.get(c, 0) + 1
145
+ return -sum(
146
+ (count / length) * math.log2(count / length)
147
+ for count in freq.values()
148
+ if count > 0
149
+ )
150
+
151
+ async def _resolve_ips(self, domains: List[str]) -> Dict[str, str]:
152
+ """
153
+ Async DNS resolution for a list of domains.
154
+ Returns dict mapping domain β†’ IP address.
155
+ """
156
+ results: Dict[str, str] = {}
157
+ loop = asyncio.get_running_loop()  # inside a coroutine, so use the running loop (get_event_loop is deprecated here)
158
+
159
+ async def resolve_one(domain: str) -> Tuple[str, str]:
160
+ try:
161
+ ip = await asyncio.wait_for(
162
+ loop.run_in_executor(None, socket.gethostbyname, domain),
163
+ timeout=2.0,
164
+ )
165
+ return domain, ip
166
+ except Exception:
167
+ return domain, ""
168
+
169
+ tasks = [resolve_one(d) for d in domains]
170
+ resolved = await asyncio.gather(*tasks, return_exceptions=True)
171
+ for item in resolved:
172
+ if isinstance(item, tuple):
173
+ domain, ip = item
174
+ if ip:
175
+ results[domain] = ip
176
+ return results
177
+
178
+ def _add_shared_ip_edges(
179
+ self, domains: List[str], ips: Dict[str, str]
180
+ ) -> List[Tuple[int, int]]:
181
+ """
182
+ Create edges between nodes that share the same IP address.
183
+ Returns list of (src, dst) index pairs.
184
+ """
185
+ edges: List[Tuple[int, int]] = []
186
+ # Group domain indices by IP
187
+ ip_to_indices: Dict[str, List[int]] = {}
188
+ for idx, domain in enumerate(domains):
189
+ ip = ips.get(domain, "")
190
+ if ip:
191
+ ip_to_indices.setdefault(ip, []).append(idx)
192
+
193
+ # Create edges between all nodes sharing an IP
194
+ for ip, indices in ip_to_indices.items():
195
+ for i in range(len(indices)):
196
+ for j in range(i + 1, len(indices)):
197
+ edges.append((indices[i], indices[j]))
198
+ edges.append((indices[j], indices[i])) # bidirectional
199
+
200
+ return edges
201
+
202
+ def _add_shared_tld_edges(self, domains: List[str]) -> List[Tuple[int, int]]:
203
+ """
204
+ Create edges between nodes that share the same suspicious TLD.
205
+ """
206
+ edges: List[Tuple[int, int]] = []
207
+ tld_to_indices: Dict[str, List[int]] = {}
208
+
209
+ for idx, domain in enumerate(domains):
210
+ for tld in SUSPICIOUS_TLDS:
211
+ if domain.endswith(tld):
212
+ tld_to_indices.setdefault(tld, []).append(idx)
213
+ break
214
+
215
+ for tld, indices in tld_to_indices.items():
216
+ for i in range(len(indices)):
217
+ for j in range(i + 1, len(indices)):
218
+ edges.append((indices[i], indices[j]))
219
+ edges.append((indices[j], indices[i]))
220
+
221
+ return edges
222
+
223
+ def build_graph(self, urls: List[str], resolve_dns: bool = False) -> dict:
224
+ """
225
+ Build a graph dict from a list of URLs.
226
+
227
+ Returns dict with:
228
+ - features: np.ndarray of shape (N, 12)
229
+ - edges: List of (src, dst) pairs
230
+ - node_count: int
231
+ - edge_count: int
232
+ - domains: List[str]
233
+ """
234
+ if not urls:
235
+ return {
236
+ "features": np.zeros((1, 12), dtype=np.float32),
237
+ "edges": [],
238
+ "node_count": 0,
239
+ "edge_count": 0,
240
+ "domains": [],
241
+ }
242
+
243
+ # Extract features for each URL
244
+ features = np.array(
245
+ [self.extract_node_features(url) for url in urls],
246
+ dtype=np.float32,
247
+ )
248
+
249
+ # Extract domains
250
+ domains: List[str] = []
251
+ for url in urls:
252
+ try:
253
+ parsed = urlparse(url if "://" in url else f"http://{url}")
254
+ domains.append((parsed.hostname or "").lower())
255
+ except Exception:
256
+ domains.append("")
257
+
258
+ # Build edges from shared TLDs (synchronous, fast)
259
+ edges = self._add_shared_tld_edges(domains)
260
+
261
+ # Optionally resolve DNS for shared IP edges
262
+ if resolve_dns and len(domains) > 1:
263
+ try:
264
+ loop = asyncio.get_event_loop()
265
+ if loop.is_running():
266
+ # Already in async context
267
+ pass
268
+ else:
269
+ ips = loop.run_until_complete(self._resolve_ips(domains))
270
+ edges.extend(self._add_shared_ip_edges(domains, ips))
271
+ except RuntimeError:
272
+ pass # Cannot resolve in this context
273
+
274
+ return {
275
+ "features": features,
276
+ "edges": edges,
277
+ "node_count": len(urls),
278
+ "edge_count": len(edges),
279
+ "domains": domains,
280
+ }
281
+
282
+ def build_single_node_graph(self, url: str) -> dict:
283
+ """
284
+ Build a single-node graph for MLP fallback path.
285
+ Used when a graph has fewer than 2 nodes.
286
+ """
287
+ features = self.extract_node_features(url).reshape(1, -1)
288
+ return {
289
+ "features": features,
290
+ "edges": [],
291
+ "node_count": 1,
292
+ "edge_count": 0,
293
+ "domains": [url],
294
+ }
295
+
296
+
297
+ # ── Legacy compatibility wrapper ─────────────────────────────────────
298
+ _builder = DomainGraphBuilder()
299
+
300
+
301
+ def build_domain_graph(urls: List[str]) -> dict:
302
+ """Legacy wrapper for backward compatibility."""
303
+ return _builder.build_graph(urls)
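
A sketch exercising the builder above: single-URL feature extraction, plus a tiny graph where the two .xyz nodes get linked by the shared-TLD rule:

from domain_graph_builder import DomainGraphBuilder

builder = DomainGraphBuilder()
feats = builder.extract_node_features("http://secure-login.verify-account.xyz/update")
print(feats.shape, feats[7], feats[9])   # (12,) 1.0 0.0 -> suspicious TLD, no HTTPS

graph = builder.build_graph(["http://a-login.xyz", "http://b-verify.xyz", "https://example.com"])
print(graph["edge_count"])               # 2: one bidirectional edge between the .xyz nodes
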
email_analyzer.py ADDED
@@ -0,0 +1,191 @@
1
+ # ============================================================
2
+ # PhishGuard AI - email_analyzer.py
3
+ # Analyzes raw emails for phishing indicators.
4
+ # Checks: sender authentication (SPF/DKIM/DMARC),
5
+ # brand spoofing, urgency language, and embedded links.
6
+ #
7
+ # Reuses BERT model from bert_analyzer to avoid duplicate loading.
8
+ # ============================================================
9
+
10
+ import email
11
+ import re
12
+ from email import policy
13
+ from email.parser import BytesParser, Parser
14
+
15
+ # Reuse the NLP analyzer from bert_analyzer
16
+ from bert_analyzer import analyze_text as bert_analyze_text, _ensure_bert_loaded
17
+ import bert_analyzer
18
+
19
+ print("[PhishGuard] Email analyzer initialized (reusing shared NLP)")
20
+
21
+ URGENCY_PATTERNS = [
22
+ r'(act now|immediate action|urgent|verify immediately|account suspended)',
23
+ r'(click here to (verify|confirm|update|restore))',
24
+ r'(your account (will be|has been) (suspended|closed|deactivated))',
25
+ r'(limited time|expires in \d+ hours?)',
26
+ r'(unusual (sign-in|login|activity) detected)',
27
+ r'(confirm your (identity|password|email|account))',
28
+ r'(we noticed (suspicious|unusual|unauthorized))',
29
+ ]
30
+
31
+ BRAND_SPOOFS = [
32
+ 'paypal','amazon','apple','microsoft','google','netflix',
33
+ 'facebook','instagram','linkedin','twitter','chase','wellsfargo',
34
+ 'bankofamerica','citibank','irs','fedex','ups','dhl',
35
+ 'dropbox','docusign','zoom','office365','hdfc','icici','sbi'
36
+ ]
37
+
38
+
39
+ def parse_email_msg(raw):
40
+ """Parse raw email bytes or string into an email.message object."""
41
+ if isinstance(raw, bytes):
42
+ return BytesParser(policy=policy.default).parsebytes(raw)
43
+ return Parser(policy=policy.default).parsestr(raw)
44
+
45
+
46
+ def extract_urls(text: str) -> list:
47
+ """Extract all unique HTTP/HTTPS URLs from text."""
48
+ return list(set(re.findall(r'https?://[^\s<>"\'\\ ]+', text)))
49
+
50
+
51
+ def get_body(msg) -> str:
52
+ """Extract plain text body from email message, falling back to HTML stripped of tags."""
53
+ parts = []
54
+ if msg.is_multipart():
55
+ for part in msg.walk():
56
+ ct = part.get_content_type()
57
+ if ct == 'text/plain':
58
+ try: parts.append(part.get_content())
59
+ except Exception: pass
60
+ elif ct == 'text/html' and not parts:
61
+ try: parts.append(re.sub(r'<[^>]+>', ' ', part.get_content()))
62
+ except Exception: pass
63
+ else:
64
+ try: parts.append(msg.get_content())
65
+ except Exception: pass
66
+ return ' '.join(parts)
67
+
68
+
69
+ def check_sender_auth(msg) -> dict:
70
+ """
71
+ Check email authentication headers:
72
+ - SPF (Sender Policy Framework)
73
+ - DKIM (DomainKeys Identified Mail)
74
+ - DMARC (Domain-based Message Authentication)
75
+ - From/Return-Path domain mismatch
76
+ - Free email provider usage
77
+ """
78
+ auth = msg.get('Authentication-Results', '').lower()
79
+ spf_raw = msg.get('Received-SPF', '').lower()
80
+ spf_pass = 'spf=pass' in auth or 'pass' in spf_raw
81
+ dkim_pass = 'dkim=pass' in auth
82
+ dmarc_pass = 'dmarc=pass' in auth
83
+
84
+ from_addr = msg.get('From', '')
85
+ return_path = msg.get('Return-Path', '')
86
+ from_dom = re.search(r'@([\w.-]+)', from_addr)
87
+ ret_dom = re.search(r'@([\w.-]+)', return_path)
88
+ mismatch = bool(from_dom and ret_dom and
89
+ from_dom.group(1) != ret_dom.group(1))
90
+
91
+ free = {'gmail.com','yahoo.com','hotmail.com','outlook.com','protonmail.com'}
92
+ using_free = (from_dom.group(1).lower() in free) if from_dom else False
93
+
94
+ risk = 0
95
+ if not spf_pass: risk += 25
96
+ if not dkim_pass: risk += 20
97
+ if not dmarc_pass: risk += 15
98
+ if mismatch: risk += 30
99
+ if using_free: risk += 10
100
+
101
+ return {
102
+ "spf_pass": spf_pass, "dkim_pass": dkim_pass,
103
+ "dmarc_pass": dmarc_pass, "domain_mismatch": mismatch,
104
+ "using_free_email": using_free,
105
+ "auth_risk_score": min(risk, 100)
106
+ }
107
+
108
+
109
+ def check_brand_spoofing(subject: str, body: str, sender: str) -> dict:
110
+ """Detect brand names mentioned in email content but not matching sender domain."""
111
+ combined = (subject + ' ' + body + ' ' + sender).lower()
112
+ sender_dom = re.search(r'@([\w.-]+)', sender)
113
+ s_dom = sender_dom.group(1).lower() if sender_dom else ''
114
+ spoofed = [b for b in BRAND_SPOOFS
115
+ if b in combined and b not in s_dom]
116
+ return {
117
+ "brand_spoof_detected": bool(spoofed),
118
+ "spoofed_brands": spoofed
119
+ }
120
+
121
+
122
+ def check_urgency(text: str) -> dict:
123
+ """Detect urgency/pressure language patterns typical of phishing emails."""
124
+ matches = []
125
+ for pat in URGENCY_PATTERNS:
126
+ found = re.findall(pat, text.lower())
127
+ matches.extend(found)
128
+ return {
129
+ "urgency_detected": bool(matches),
130
+ "urgency_matches": [str(m) for m in matches[:5]],
131
+ "urgency_score": min(len(matches) * 15, 60)
132
+ }
133
+
134
+
135
+ def bert_score(text: str) -> float:
136
+ """Run NLP classifier on email text and return phishing probability."""
137
+ if not text.strip():
138
+ return 0.1
139
+ try:
140
+ _ensure_bert_loaded()
141
+ if bert_analyzer._use_bert and bert_analyzer._classifier is not None:
142
+ result = bert_analyzer._classifier(text[:512])[0]
143
+ label = result['label'].upper()
144
+ score = result['score']
145
+ return score if ('SPAM' in label or label == 'LABEL_1') else 1 - score
146
+ else:
147
+ # Use keyword analysis from bert_analyzer
148
+ result = bert_analyze_text("", "", text)
149
+ return result.get("bert_phishing_prob", 0.3)
150
+ except Exception:
151
+ return 0.3
152
+
153
+
154
+ def analyze_email(raw, return_urls: bool = True) -> dict:
155
+ """
156
+ Full phishing analysis of a raw email.
157
+ Pass raw bytes or a string of the full email.
158
+
159
+ Combines: BERT NLP score + sender auth + brand spoofing + urgency detection.
160
+ """
161
+ msg = parse_email_msg(raw)
162
+ subject = msg.get('Subject', '')
163
+ sender = msg.get('From', '')
164
+ body = get_body(msg)
165
+ urls = extract_urls(body)
166
+
167
+ auth = check_sender_auth(msg)
168
+ brand = check_brand_spoofing(subject, body, sender)
169
+ urgency = check_urgency(subject + ' ' + body)
170
+ bert_p = bert_score(subject + '. ' + body[:400])
171
+
172
+ raw_score = (bert_p * 40 +
173
+ auth['auth_risk_score'] * 0.30 +
174
+ urgency['urgency_score'] * 0.20 +
175
+ (30 if brand['brand_spoof_detected'] else 0) * 0.10)
176
+ final = min(raw_score / 100, 1.0)
177
+
178
+ result = {
179
+ "is_phishing": final > 0.60,
180
+ "phishing_probability": round(final, 4),
181
+ "subject": subject,
182
+ "sender": sender,
183
+ "auth_analysis": auth,
184
+ "brand_analysis": brand,
185
+ "urgency_analysis": urgency,
186
+ "bert_score": round(bert_p, 4),
187
+ "extracted_url_count": len(urls),
188
+ }
189
+ if return_urls:
190
+ result["extracted_urls"] = urls[:20]
191
+ return result
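
Given the weights in analyze_email, raw_score is capped at 40 (BERT) + 30 (auth, 100 Γ— 0.30) + 12 (urgency, 60 Γ— 0.20) + 3 (brand, 30 Γ— 0.10) = 85 points, so phishing_probability tops out at 0.85. A usage sketch, runnable inside this repo where bert_analyzer is importable:

from email_analyzer import analyze_email

raw = (
    "From: PayPal Support <support@mail-secure.xyz>\n"
    "Return-Path: <bounce@relay-host.tk>\n"
    "Subject: Urgent: verify immediately or your account will be suspended\n"
    "\n"
    "Click here to verify your account: http://paypal-verify.xyz/login\n"
)
report = analyze_email(raw)
print(report["is_phishing"], report["phishing_probability"])
print(report["auth_analysis"]["domain_mismatch"])   # True: From and Return-Path domains differ
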
feedback_store.py ADDED
@@ -0,0 +1,223 @@
1
+ # ============================================================
2
+ # PhishGuard AI - feedback_store.py
3
+ # Thread-safe feedback storage, retraining trigger, analytics.
4
+ #
5
+ # Storage: feedback_data.jsonl (append-only, one JSON per line)
6
+ # Lock: asyncio.Lock prevents concurrent writes & double-retrain
7
+ # ============================================================
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import json
13
+ import time
14
+ import asyncio
15
+ import shutil
16
+ import logging
17
+ from datetime import datetime, timezone
18
+ from typing import Optional
19
+
20
+ logger = logging.getLogger("phishguard.feedback")
21
+
22
+ _BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23
+ FEEDBACK_FILE = os.path.join(_BASE_DIR, "feedback_data.jsonl")
24
+ STATE_FILE = os.path.join(_BASE_DIR, "retrain_state.json")
25
+
26
+ # ── Async lock for thread-safe writes ────────────────────────────────────────
27
+ _write_lock = asyncio.Lock()
28
+
29
+ # ── Retrain state (persisted to retrain_state.json) ──────────────────────────
30
+ _retrain_state = {
31
+ "model_version": 1,
32
+ "total_feedback": 0,
33
+ "unprocessed_count": 0,
34
+ "phishing_corrections": 0,
35
+ "safe_corrections": 0,
36
+ "last_retrain": None, # ISO 8601 timestamp
37
+ "retrain_history": [], # [{ts, samples, accuracy, version}]
38
+ }
39
+
40
+
41
+ def _load_state():
42
+ """Load persisted retrain state from disk."""
43
+ global _retrain_state
44
+ if os.path.exists(STATE_FILE):
45
+ try:
46
+ with open(STATE_FILE, "r") as f:
47
+ saved = json.load(f)
48
+ _retrain_state.update(saved)
49
+ logger.info(f"[FeedbackStore] State loaded | version={_retrain_state['model_version']} | total={_retrain_state['total_feedback']}")
50
+ except Exception as e:
51
+ logger.warning(f"[FeedbackStore] Could not load state: {e}")
52
+
53
+
54
+ def _save_state():
55
+ """Persist retrain state to disk (atomic write)."""
56
+ try:
57
+ tmp = STATE_FILE + ".tmp"
58
+ with open(tmp, "w") as f:
59
+ json.dump(_retrain_state, f, indent=2, default=str)
60
+ os.replace(tmp, STATE_FILE)
61
+ except Exception as e:
62
+ logger.warning(f"[FeedbackStore] Could not save state: {e}")
63
+
64
+
65
+ # Load state on module import
66
+ _load_state()
67
+
68
+
69
+ # ══════════════════════════════════════════════════════════════════════════════
70
+ # FEEDBACK STORAGE
71
+ # ══════════════════════════════════════════════════════════════════════════════
72
+
73
+ async def append_feedback(
74
+ url: str,
75
+ label: str,
76
+ source: str = "user_feedback",
77
+ original_prediction: Optional[float] = None,
78
+ ) -> dict:
79
+ """
80
+ Thread-safe append of a feedback entry to feedback_data.jsonl.
81
+
82
+ Returns: {"success": True, "feedback_count": N, "unprocessed": M}
83
+ """
84
+ entry = {
85
+ "url": url,
86
+ "label": label, # "phishing" or "safe"
87
+ "timestamp": datetime.now(timezone.utc).isoformat(),
88
+ "source": source,
89
+ "original_prediction": round(original_prediction, 4) if original_prediction is not None else None,
90
+ }
91
+
92
+ async with _write_lock:
93
+ try:
94
+ with open(FEEDBACK_FILE, "a") as f:
95
+ f.write(json.dumps(entry) + "\n")
96
+ except Exception as e:
97
+ logger.error(f"[FeedbackStore] Write failed: {e}")
98
+ return {"success": False, "error": str(e)}
99
+
100
+ # Update in-memory state
101
+ _retrain_state["total_feedback"] += 1
102
+ _retrain_state["unprocessed_count"] += 1
103
+ if label == "phishing":
104
+ _retrain_state["phishing_corrections"] += 1
105
+ elif label == "safe":
106
+ _retrain_state["safe_corrections"] += 1
107
+
108
+ _save_state()
109
+
110
+ logger.info(f"[FeedbackStore] Saved | url={url} | label={label} | total={_retrain_state['total_feedback']}")
111
+
112
+ return {
113
+ "success": True,
114
+ "feedback_count": _retrain_state["total_feedback"],
115
+ "unprocessed": _retrain_state["unprocessed_count"],
116
+ }
117
+
118
+
119
+ def get_unprocessed_count() -> int:
120
+ """Number of feedback entries since last retraining."""
121
+ return _retrain_state["unprocessed_count"]
122
+
123
+
124
+ def get_model_version() -> int:
125
+ """Current model version number."""
126
+ return _retrain_state["model_version"]
127
+
128
+
129
+ def get_stats() -> dict:
130
+ """Return feedback analytics for the /feedback/stats endpoint."""
131
+ return {
132
+ "total_feedback": _retrain_state["total_feedback"],
133
+ "phishing_corrections": _retrain_state["phishing_corrections"],
134
+ "safe_corrections": _retrain_state["safe_corrections"],
135
+ "unprocessed_count": _retrain_state["unprocessed_count"],
136
+ "last_retrain": _retrain_state["last_retrain"],
137
+ "model_version": _retrain_state["model_version"],
138
+ "retrain_history": _retrain_state["retrain_history"][-10:], # last 10
139
+ }
140
+
141
+
142
+ def get_recent_entries(n: int = 50) -> list:
143
+ """Read the last N feedback entries from the JSONL file."""
144
+ if not os.path.exists(FEEDBACK_FILE):
145
+ return []
146
+ try:
147
+ with open(FEEDBACK_FILE, "r") as f:
148
+ lines = f.readlines()
149
+ entries = []
150
+ for line in lines[-n:]:
151
+ line = line.strip()
152
+ if line:
153
+ entries.append(json.loads(line))
154
+ return entries
155
+ except Exception:
156
+ return []
157
+
158
+
159
+ # ══════════════════════════════════════════════════════════════════════════════
160
+ # RETRAINING PIPELINE
161
+ # ══════════════════════════════════════════════════════════════════════════════
162
+
163
+ RETRAIN_THRESHOLD = 50
164
+ _retrain_running = False
165
+
166
+
167
+ def should_retrain() -> bool:
168
+ """Check if retraining should be triggered."""
169
+ return (
170
+ _retrain_state["unprocessed_count"] >= RETRAIN_THRESHOLD
171
+ and not _retrain_running
172
+ )
173
+
174
+
175
+ def mark_retrain_complete(samples: int, accuracy: float):
176
+ """
177
+ Called after successful retraining.
178
+ Increments model_version, resets unprocessed counter, logs history.
179
+ """
180
+ _retrain_state["model_version"] += 1
181
+ _retrain_state["unprocessed_count"] = 0
182
+ _retrain_state["last_retrain"] = datetime.now(timezone.utc).isoformat()
183
+ _retrain_state["retrain_history"].append({
184
+ "timestamp": _retrain_state["last_retrain"],
185
+ "samples": samples,
186
+ "accuracy": round(accuracy, 4),
187
+ "version": _retrain_state["model_version"],
188
+ })
189
+ # Keep only last 50 history entries
190
+ if len(_retrain_state["retrain_history"]) > 50:
191
+ _retrain_state["retrain_history"] = _retrain_state["retrain_history"][-50:]
192
+ _save_state()
193
+ logger.info(
194
+ f"[FeedbackStore] Retrained on {samples} feedback samples. "
195
+ f"New accuracy: {accuracy:.2%}. Model version: {_retrain_state['model_version']}"
196
+ )
197
+
198
+
199
+ def archive_feedback_file():
200
+ """Move the processed feedback file to a timestamped backup."""
201
+ if os.path.exists(FEEDBACK_FILE):
202
+ archive = FEEDBACK_FILE + f".{int(time.time())}.bak"
203
+ try:
204
+ shutil.move(FEEDBACK_FILE, archive)
205
+ logger.info(f"[FeedbackStore] Archived feedback β†’ {archive}")
206
+ except Exception as e:
207
+ logger.warning(f"[FeedbackStore] Archive failed: {e}")
208
+
209
+
210
+ def load_feedback_entries() -> list:
211
+ """Load ALL entries from the feedback JSONL file."""
212
+ if not os.path.exists(FEEDBACK_FILE):
213
+ return []
214
+ entries = []
215
+ try:
216
+ with open(FEEDBACK_FILE, "r") as f:
217
+ for line in f:
218
+ line = line.strip()
219
+ if line:
220
+ entries.append(json.loads(line))
221
+ except Exception as e:
222
+ logger.error(f"[FeedbackStore] Read failed: {e}")
223
+ return entries
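A minimal usage sketch of this module (assumes it is importable as feedback_store; the URL and numbers are hypothetical):

    import asyncio
    import feedback_store

    async def demo():
        # Thread-safe JSONL append of one labeled feedback entry
        res = await feedback_store.append_feedback(
            "http://suspicious.example/login", "phishing", original_prediction=0.42
        )
        print(res["feedback_count"], res["unprocessed"])
        # should_retrain() becomes True once RETRAIN_THRESHOLD (50) unprocessed entries accumulate
        if feedback_store.should_retrain():
            # ... run the retraining pipeline here, then record the outcome:
            feedback_store.mark_retrain_complete(samples=50, accuracy=0.95)

    asyncio.run(demo())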
generate_icons.py ADDED
@@ -0,0 +1,98 @@
1
+ #!/usr/bin/env python3
2
+ """Generate PhishGuard extension icons at 16, 48, 128 px using pure Python."""
3
+ import struct, zlib, os
4
+
5
+ def create_png(width, height, pixels):
6
+ """Create a minimal PNG from RGBA pixel data."""
7
+ def chunk(chunk_type, data):
8
+ c = chunk_type + data
9
+ return struct.pack('>I', len(data)) + c + struct.pack('>I', zlib.crc32(c) & 0xffffffff)
10
+
11
+ raw = b''
12
+ for y in range(height):
13
+ raw += b'\x00'
14
+ for x in range(width):
15
+ idx = (y * width + x) * 4
16
+ raw += bytes(pixels[idx:idx+4])
17
+
18
+ return (b'\x89PNG\r\n\x1a\n' +
19
+ chunk(b'IHDR', struct.pack('>IIBBBBB', width, height, 8, 6, 0, 0, 0)) +
20
+ chunk(b'IDAT', zlib.compress(raw)) +
21
+ chunk(b'IEND', b''))
22
+
23
+ def draw_shield_icon(size):
24
+ """Draw a shield icon with checkmark."""
25
+ pixels = [0] * (size * size * 4)
26
+ cx, cy = size / 2, size / 2
27
+ sr, sg, sb = 0x53, 0x4A, 0xB7
28
+ hr, hg, hb = 0x7B, 0x73, 0xD4
29
+
30
+ for y in range(size):
31
+ for x in range(size):
32
+ idx = (y * size + x) * 4
33
+ nx = (x - cx) / (size / 2)
34
+ ny = (y - cy) / (size / 2)
35
+
36
+ in_shield = False
37
+ if ny < -0.05:
38
+ if abs(nx) < 0.75:
39
+ in_shield = True
40
+ elif ny < 0.5:
41
+ w = 0.75 * (1 - ny * 0.8)
42
+ if abs(nx) < w:
43
+ in_shield = True
44
+ else:
45
+ w = 0.75 * max(0, (1.0 - ny) * 1.4)
46
+ if abs(nx) < w:
47
+ in_shield = True
48
+
49
+ if ny < -0.8:
50
+ in_shield = False
51
+
52
+ if in_shield:
53
+ blend = max(0, min(1, 0.5 - nx * 0.3 - ny * 0.2))
54
+ r = int(sr + (hr - sr) * blend)
55
+ g = int(sg + (hg - sg) * blend)
56
+ b = int(sb + (hb - sb) * blend)
57
+ pixels[idx:idx+4] = [r, g, b, 255]
58
+ else:
59
+ pixels[idx:idx+4] = [0, 0, 0, 0]
60
+
61
+ if size >= 32:
62
+ check_points = []
63
+ for t in range(100):
64
+ p = t / 100.0
65
+ if p < 0.4:
66
+ pp = p / 0.4
67
+ px = int(cx + (-0.25 + pp * 0.25) * size * 0.6)
68
+ py = int(cy + (-0.1 + pp * 0.3) * size * 0.6)
69
+ else:
70
+ pp = (p - 0.4) / 0.6
71
+ px = int(cx + (0.0 + pp * 0.35) * size * 0.6)
72
+ py = int(cy + (0.2 - pp * 0.45) * size * 0.6)
73
+ check_points.append((px, py))
74
+
75
+ thickness = max(1, int(size * 0.06))
76
+ for px, py in check_points:
77
+ for dy in range(-thickness, thickness+1):
78
+ for dx in range(-thickness, thickness+1):
79
+ xx, yy = px + dx, py + dy
80
+ if 0 <= xx < size and 0 <= yy < size:
81
+ idx = (yy * size + xx) * 4
82
+ if pixels[idx+3] > 0:
83
+ pixels[idx:idx+4] = [255, 255, 255, 240]
84
+
85
+ return pixels
86
+
87
+ icons_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'extension', 'icons')
88
+ os.makedirs(icons_dir, exist_ok=True)
89
+
90
+ for size in [16, 48, 128]:
91
+ pixels = draw_shield_icon(size)
92
+ png_data = create_png(size, size, pixels)
93
+ path = os.path.join(icons_dir, f'icon{size}.png')
94
+ with open(path, 'wb') as f:
95
+ f.write(png_data)
96
+ print(f"Created {path} ({len(png_data)} bytes)")
97
+
98
+ print("Done! All icons generated.")
gmail_scanner.js ADDED
@@ -0,0 +1,193 @@
1
+ // gmail_scanner.js
2
+
3
+ console.log("PhishGuard AI: Gmail Scanner loaded.");
4
+
5
+ // Gmail's DOM can change, but the email body text is commonly stored in elements matching '.a3s' or '.ii.gt'
6
+ const EMAIL_BODY_SELECTOR = '.a3s, .ii.gt';
7
+
8
+ // Function to inject a visible warning banner into the Gmail UI
9
+ function injectWarningBanner(emailContainer, message) {
10
+ // Prevent duplicate banners if one is already injected
11
+ if (emailContainer.parentNode && emailContainer.parentNode.querySelector('.phishguard-banner')) {
12
+ return;
13
+ }
14
+
15
+ const banner = document.createElement('div');
16
+ banner.className = 'phishguard-banner';
17
+
18
+ // Styling the banner to look native but urgent, fitting Google's Material Design
19
+ banner.style.backgroundColor = '#fce8e6';
20
+ banner.style.color = '#c5221f';
21
+ banner.style.border = '1px solid #faa59f';
22
+ banner.style.borderRadius = '8px';
23
+ banner.style.padding = '12px 16px';
24
+ // Added margin to ensure it doesn't overlap text awkwardly
25
+ banner.style.margin = '16px auto';
26
+ banner.style.fontFamily = '"Google Sans", Roboto, Arial, sans-serif';
27
+ banner.style.fontSize = '14px';
28
+ banner.style.fontWeight = '500';
29
+ banner.style.lineHeight = '20px';
30
+ banner.style.display = 'flex';
31
+ banner.style.alignItems = 'center';
32
+ banner.style.boxShadow = '0 1px 2px 0 rgba(60,64,67,0.3), 0 1px 3px 1px rgba(60,64,67,0.15)';
33
+
34
+ // Add SVG Icon for warning
35
+ const iconSvg = `
36
+ <svg focusable="false" width="24" height="24" viewBox="0 0 24 24" style="fill: #c5221f; margin-right: 16px; flex-shrink: 0;">
37
+ <path d="M1 21h22L12 2 1 21zm12-3h-2v-2h2v2zm0-4h-2v-4h2v4z"></path>
38
+ </svg>
39
+ `;
40
+
41
+ const textContent = document.createElement('span');
42
+ textContent.innerText = message || '🚨 PhishGuard Warning: This email contains suspicious links.';
43
+
44
+ // Construct the banner
45
+ banner.innerHTML = iconSvg;
46
+ banner.appendChild(textContent);
47
+
48
+ // Insert banner at the top of the email container.
49
+ // By inserting before the email container, we keep it visible at the top of the body.
50
+ if (emailContainer.parentNode) {
51
+ emailContainer.parentNode.insertBefore(banner, emailContainer);
52
+ }
53
+ }
54
+
55
+
56
+ // Helper function to extract all unique URLs from the email body
57
+ function extractUrlsFromBody(emailContainer) {
58
+ const links = emailContainer.querySelectorAll('a[href]');
59
+ const urls = new Set(); // Use a Set to store unique URLs
60
+
61
+ links.forEach(link => {
62
+ const href = link.href;
63
+ // Basic filter to ignore mailto: or javascript: links, and only keep http/https
64
+ if (href && (href.startsWith('http://') || href.startsWith('https://'))) {
65
+ urls.add(href);
66
+ }
67
+ });
68
+
69
+ return Array.from(urls);
70
+ }
71
+
72
+ // Function to safely extract the sender's actual email address
73
+ function extractSenderEmail(emailContainer) {
74
+ // Gmail usually groups each email message into a container block.
75
+ // We traverse up to find a common parent containing both header and body.
76
+ // '.kv' or 'table' or '.adn' is often the parent wrapper for an individual message.
77
+ let messageWrapper = emailContainer.closest('.adn') || emailContainer.closest('.kv') || document;
78
+
79
+ // The sender element usually has the class '.gD' and contains the 'email' attribute.
80
+ const senderElement = messageWrapper.querySelector('.gD');
81
+
82
+ if (senderElement && senderElement.getAttribute('email')) {
83
+ return senderElement.getAttribute('email');
84
+ }
85
+
86
+ // Fallback: Sometimes the email address is enclosed in brackets inside a '.go' element.
87
+ const fallbackSenderElement = messageWrapper.querySelector('.go');
88
+ if (fallbackSenderElement && fallbackSenderElement.innerText) {
89
+ // e.g. "<sender@example.com>" -> matched and extracted
90
+ const match = fallbackSenderElement.innerText.match(/<([^>]+)>/);
91
+ if (match && match[1]) {
92
+ return match[1];
93
+ }
94
+ }
95
+
96
+ return "Unknown Sender";
97
+ }
98
+
99
+ // Function to handle newly opened emails
100
+ function handleEmailOpened(emailContainer) {
101
+ // Prevent re-scanning the same email element
102
+ if (emailContainer.dataset.pgScanned === "true") {
103
+ return;
104
+ }
105
+
106
+ console.log("PhishGuard AI: New email thread opened. Extracting data...");
107
+
108
+ // Mark as scanned
109
+ emailContainer.dataset.pgScanned = "true";
110
+
111
+ // 1. Extract plain text content
112
+ const emailBodyText = emailContainer.innerText;
113
+
114
+ // 2. Extract embedded URLs
115
+ const urls = extractUrlsFromBody(emailContainer);
116
+
117
+ // 3. Extract sender email address
118
+ const sender = extractSenderEmail(emailContainer);
119
+
120
+ // 4. Extract subject line (.hP is a standard class for Gmail's main subject line)
121
+ // Sometimes the subject might be in a '.bog' element. We'll default to 'h2.hP'.
122
+ const subjectElement = document.querySelector('h2.hP') || document.querySelector('.bog');
123
+ const subject = subjectElement ? subjectElement.innerText.trim() : "No Subject Found";
124
+
125
+ // Package the extracted data into a JSON payload
126
+ const emailPayload = {
127
+ sender: sender,
128
+ subject: subject,
129
+ body: emailBodyText,
130
+ urls: urls,
131
+ timestamp: new Date().toISOString()
132
+ };
133
+
134
+ console.log("PhishGuard AI extracted payload:", emailPayload);
135
+
136
+ // Send background message to service worker
137
+ chrome.runtime.sendMessage(
138
+ {
139
+ action: "analyzeEmail",
140
+ data: emailPayload
141
+ },
142
+ (response) => {
143
+ if (chrome.runtime.lastError) {
144
+ console.error("PhishGuard AI: Error communicating with background script:", chrome.runtime.lastError);
145
+ return;
146
+ }
147
+
148
+ console.log("PhishGuard AI Background Analysis Response:", response);
149
+
150
+ // Assume the background script returns `response.analysis` containing `probability` or `isPhishing` flag
151
+ const analysis = response && response.analysis ? response.analysis : {};
152
+
153
+ if (analysis.isPhishing === true || analysis.probability > 0.70) {
154
+ console.warn("PhishGuard AI: High risk email detected! Injecting banner...");
155
+ injectWarningBanner(
156
+ emailContainer,
157
+ '🚨 PhishGuard Warning: This email contains suspicious links and exhibits high-risk phishing behavior.'
158
+ );
159
+ }
160
+ }
161
+ );
162
+ }
163
+
164
+ // Set up a MutationObserver to watch for DOM changes
165
+ // This effectively detects when Gmail dynamically loads an individual email view into the DOM
166
+ const observer = new MutationObserver((mutationsList) => {
167
+ for (const mutation of mutationsList) {
168
+ if (mutation.type === 'childList') {
169
+ mutation.addedNodes.forEach(node => {
170
+ if (node.nodeType === Node.ELEMENT_NODE) {
171
+ // Check if the added node itself is the email body container
172
+ if (node.matches && node.matches(EMAIL_BODY_SELECTOR)) {
173
+ handleEmailOpened(node);
174
+ }
175
+
176
+ // Also search recursively within the added node's subtree for the email body
177
+ if (node.querySelectorAll) {
178
+ const emailBodies = node.querySelectorAll(EMAIL_BODY_SELECTOR);
179
+ emailBodies.forEach(body => handleEmailOpened(body));
180
+ }
181
+ }
182
+ });
183
+ }
184
+ }
185
+ });
186
+
187
+ // Start observing the document body for deeper added nodes (like when navigating between emails)
188
+ observer.observe(document.body, {
189
+ childList: true,
190
+ subtree: true
191
+ });
192
+
193
+ console.log("PhishGuard AI: MutationObserver is listening for email thread opens.");
gnn_inference.py ADDED
@@ -0,0 +1,274 @@
1
+ # ============================================================
2
+ # PhishGuard AI - gnn/gnn_inference.py
3
+ # GNN inference wrapper for main.py.
4
+ # Loads model once at startup, reuses for every request.
5
+ # Supports: predict, hot-reload, incremental_update.
6
+ # ============================================================
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ import sys
12
+ import random
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import List, Optional, Tuple
16
+
17
+ import torch
18
+
19
+ logger = logging.getLogger("phishguard.gnn.inference")
20
+
21
+ # Add parent paths
22
+ _GNN_DIR = Path(__file__).parent
23
+ _BACKEND_DIR = _GNN_DIR.parent
24
+ sys.path.insert(0, str(_GNN_DIR))
25
+ sys.path.insert(0, str(_BACKEND_DIR))
26
+
27
+ from domain_graph_builder import DomainGraphBuilder
28
+ from gnn_model import load_gnn_model, PhishMLP, PYGEOM_AVAILABLE, INPUT_DIM
29
+
30
+ if PYGEOM_AVAILABLE:
31
+ from gnn_model import PhishGNN
32
+
33
+ MODEL_PATH = _GNN_DIR / "gnn_weights.pt"
34
+ REPLAY_BUFFER_PATH = _BACKEND_DIR / "data" / "gnn_replay_buffer.pt"
35
+
36
+
37
+ class GNNInference:
38
+ """
39
+ GNN inference wrapper with hot-reload and incremental update support.
40
+ """
41
+
42
+ def __init__(self, weights_path: Optional[Path] = None) -> None:
43
+ self._weights_path = weights_path or MODEL_PATH
44
+ self._model: Optional[torch.nn.Module] = None
45
+ self._builder = DomainGraphBuilder()
46
+ self._loaded = False
47
+
48
+ def load(self, weights_path: Optional[Path] = None) -> bool:
49
+ """Load GNN model from weights file."""
50
+ path = weights_path or self._weights_path
51
+ self._model = load_gnn_model(str(path) if path.exists() else None)
52
+ self._loaded = self._model is not None
53
+ if self._loaded:
54
+ logger.info(f"GNN model loaded from {path}")
55
+ return self._loaded
56
+
57
+ def predict(self, url: str, related_urls: Optional[List[str]] = None) -> float:
58
+ """
59
+ Predict phishing probability for a URL.
60
+ Returns P_gnn ∈ [0,1].
61
+ Falls back to MLP if model unavailable or graph too small.
62
+ """
63
+ if not self._loaded:
64
+ self.load()
65
+
66
+ if self._model is None:
67
+ return 0.5 # Neutral when model unavailable
68
+
69
+ urls = [url] + (related_urls or [])
70
+
71
+ # Single URL → MLP fallback path
72
+ if len(urls) == 1:
73
+ graph = self._builder.build_single_node_graph(url)
74
+ else:
75
+ graph = self._builder.build_graph(urls)
76
+
77
+ x = torch.tensor(graph["features"], dtype=torch.float)
78
+
79
+ edges = graph["edges"]
80
+ if edges and len(edges) > 0:
81
+ edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
82
+ else:
83
+ n = x.size(0)
84
+ edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
85
+
86
+ prob = self._model.predict_proba(x, edge_index)
87
+ return round(float(prob), 4)
88
+
89
+ def reload(self, weights_path: Optional[Path] = None) -> bool:
90
+ """Hot-reload model with new weights (no server restart needed)."""
91
+ path = weights_path or self._weights_path
92
+ new_model = load_gnn_model(str(path))
93
+ if new_model is not None:
94
+ self._model = new_model
95
+ self._loaded = True
96
+ logger.info(f"GNN model hot-reloaded from {path}")
97
+ return True
98
+ logger.warning(f"GNN hot-reload failed from {path}")
99
+ return False
100
+
101
+ def incremental_update(
102
+ self,
103
+ samples: List[Tuple[str, int]],
104
+ replay_buffer_path: Optional[Path] = None,
105
+ lr: float = 5e-4,
106
+ epochs: int = 5,
107
+ ) -> Optional[float]:
108
+ """
109
+ Incremental update on feedback samples + replay buffer.
110
+ Returns accuracy_delta or None if failed.
111
+
112
+ samples: list of (url, label) where label is 0 or 1
113
+ """
114
+ if self._model is None:
115
+ logger.warning("GNN not loaded, cannot incrementally update")
116
+ return None
117
+
118
+ if len(samples) < 5:
119
+ logger.warning(f"Too few samples ({len(samples)}) for GNN update")
120
+ return None
121
+
122
+ try:
123
+ import torch.nn.functional as F
124
+
125
+ device = torch.device("cpu")
126
+ model = self._model.to(device)
127
+ builder = DomainGraphBuilder()
128
+
129
+ # Build graphs from new feedback
130
+ new_graphs = []
131
+ CHUNK = 4
132
+ phish = [url for url, label in samples if label == 1]
133
+ legit = [url for url, label in samples if label == 0]
134
+
135
+ for urls, label in [(phish, 1), (legit, 0)]:
136
+ for i in range(0, len(urls), CHUNK):
137
+ chunk = urls[i:i + CHUNK]
138
+ if not chunk:
139
+ continue
140
+ graph = builder.build_graph(chunk)
141
+ x = torch.tensor(graph["features"], dtype=torch.float)
142
+ edges = graph["edges"]
143
+ if edges:
144
+ ei = torch.tensor(edges, dtype=torch.long).t().contiguous()
145
+ else:
146
+ n = x.size(0)
147
+ ei = torch.arange(n).unsqueeze(0).repeat(2, 1)
148
+ new_graphs.append({
149
+ "x": x, "edge_index": ei,
150
+ "y": torch.tensor([float(label)]),
151
+ })
152
+
153
+ # Load replay buffer (20% mix)
154
+ buf_path = replay_buffer_path or REPLAY_BUFFER_PATH
155
+ replay_graphs = []
156
+ if buf_path.exists():
157
+ try:
158
+ all_replay = torch.load(buf_path, map_location="cpu", weights_only=False)
159
+ replay_count = max(1, len(all_replay) // 5) # 20%
160
+ replay_graphs = random.sample(all_replay, min(replay_count, len(all_replay)))
161
+ except Exception as e:
162
+ logger.warning(f"Replay buffer load failed: {e}")
163
+
164
+ # Merge: 80% new + 20% replay
165
+ dataset = new_graphs + replay_graphs
166
+ random.shuffle(dataset)
167
+
168
+ if not dataset:
169
+ return None
170
+
171
+ # Pre-update accuracy
172
+ model.eval()
173
+ pre_correct = 0
174
+ with torch.no_grad():
175
+ for item in dataset:
176
+ out = model(item["x"].to(device), item["edge_index"].to(device))
177
+ pred = 1 if out.squeeze().item() >= 0.5 else 0
178
+ pre_correct += int(pred == int(item["y"].item()))
179
+ pre_acc = pre_correct / len(dataset)
180
+
181
+ # Train
182
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
183
+ model.train()
184
+
185
+ for epoch in range(epochs):
186
+ random.shuffle(dataset)
187
+ total_loss = 0.0
188
+ for item in dataset:
189
+ x = item["x"].to(device)
190
+ ei = item["edge_index"].to(device)
191
+ y = item["y"].to(device)
192
+ optimizer.zero_grad()
193
+ out = model(x, ei)
194
+ loss = F.binary_cross_entropy(out.squeeze(), y.squeeze())
195
+ loss.backward()
196
+ optimizer.step()
197
+ total_loss += loss.item()
198
+ logger.info(f"GNN incremental epoch {epoch+1}/{epochs}, loss={total_loss/len(dataset):.4f}")
199
+
200
+ # Post-update accuracy
201
+ model.eval()
202
+ post_correct = 0
203
+ with torch.no_grad():
204
+ for item in dataset:
205
+ out = model(item["x"].to(device), item["edge_index"].to(device))
206
+ pred = 1 if out.squeeze().item() >= 0.5 else 0
207
+ post_correct += int(pred == int(item["y"].item()))
208
+ post_acc = post_correct / len(dataset)
209
+
210
+ delta = post_acc - pre_acc
211
+ self._model = model
212
+
213
+ # Save weights
214
+ torch.save(model.state_dict(), self._weights_path)
215
+ logger.info(f"GNN incremental update: {pre_acc:.4f} β†’ {post_acc:.4f} (Ξ”={delta:+.4f})")
216
+
217
+ # Update replay buffer (rolling 500)
218
+ try:
219
+ existing = []
220
+ if buf_path.exists():
221
+ existing = torch.load(buf_path, map_location="cpu", weights_only=False)
222
+ combined = existing + new_graphs
223
+ if len(combined) > 500:
224
+ combined = combined[-500:]
225
+ buf_path.parent.mkdir(parents=True, exist_ok=True)
226
+ torch.save(combined, buf_path)
227
+ except Exception as e:
228
+ logger.warning(f"Replay buffer update failed: {e}")
229
+
230
+ return round(delta, 4)
231
+
232
+ except Exception as e:
233
+ logger.error(f"GNN incremental update failed: {e}")
234
+ return None
235
+
236
+ @property
237
+ def is_loaded(self) -> bool:
238
+ return self._loaded
239
+
240
+
241
+ # ── Legacy compatibility functions ───────────────────────────────────
242
+ _inference = GNNInference()
243
+
244
+
245
+ def analyze_url_with_gnn(url: str, related_urls: list = None) -> dict:
246
+ """Legacy wrapper for backward compatibility."""
247
+ if not _inference.is_loaded:
248
+ _inference.load()
249
+
250
+ if not _inference.is_loaded:
251
+ return {
252
+ "gnn_phish_prob": None,
253
+ "tier3_status": "model_not_loaded",
254
+ "node_count": 0,
255
+ "edge_count": 0,
256
+ "graph_suspicious": False,
257
+ }
258
+
259
+ prob = _inference.predict(url, related_urls)
260
+ return {
261
+ "gnn_phish_prob": prob,
262
+ "node_count": 1 + len(related_urls or []),
263
+ "edge_count": 0,
264
+ "graph_suspicious": prob > 0.6,
265
+ }
266
+
267
+
268
+ def reload_model(new_weights_path: str = None) -> bool:
269
+ path = Path(new_weights_path) if new_weights_path else None
270
+ return _inference.reload(path)
271
+
272
+
273
+ def is_model_loaded() -> bool:
274
+ return _inference.is_loaded
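A minimal usage sketch (the URLs are hypothetical; without a gnn_weights.pt file the wrapper still loads an untrained model, so treat its scores as placeholders):

    from gnn_inference import GNNInference

    gnn = GNNInference()
    gnn.load()                       # loads gnn_weights.pt if present
    p = gnn.predict(
        "http://login-verify.example/account",
        related_urls=["http://login-verify.example/reset"],
    )
    print(p)                         # P_gnn in [0, 1], rounded to 4 decimals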
gnn_model.py ADDED
@@ -0,0 +1,183 @@
1
+ # ============================================================
2
+ # PhishGuard AI - gnn/gnn_model.py
3
+ # GNN + MLP model definitions for phishing graph classification.
4
+ #
5
+ # PhishGNN: 3-layer GCN with global_mean_pool → Linear → Sigmoid
6
+ # GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
7
+ # GCNConv(32→16) → global_mean_pool → Linear(16→1) → Sigmoid
8
+ #
9
+ # PhishMLP: Fallback for single URL or when torch_geometric unavailable
10
+ # Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid
11
+ # ============================================================
12
+
13
+ from __future__ import annotations
14
+
15
+ import os
16
+ import logging
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ logger = logging.getLogger("phishguard.gnn.model")
24
+
25
+ INPUT_DIM: int = 12 # 12-dim node features
26
+ HIDDEN_DIM: int = 64
27
+ OUTPUT_DIM: int = 1 # binary: sigmoid output
28
+
29
+ # ── Try importing PyTorch Geometric ──────────────────────────────────
30
+ PYGEOM_AVAILABLE: bool = False
31
+ try:
32
+ from torch_geometric.nn import GCNConv, global_mean_pool
33
+ PYGEOM_AVAILABLE = True
34
+ logger.info("PyTorch Geometric found β€” using full GCN model")
35
+ except ImportError:
36
+ PYGEOM_AVAILABLE = False
37
+ logger.info("PyTorch Geometric not found β€” using MLP fallback")
38
+
39
+
40
+ # ── PhishGNN: Full 3-layer Graph Convolutional Network ───────────────
41
+ if PYGEOM_AVAILABLE:
42
+ class PhishGNN(nn.Module):
43
+ """
44
+ 3-layer GCN for graph-level phishing classification.
45
+ Architecture from spec:
46
+ GCNConv(12→64) → ReLU → GCNConv(64→32) → ReLU →
47
+ GCNConv(32→16) → global_mean_pool → Linear(16→1) → Sigmoid
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ in_channels: int = INPUT_DIM,
53
+ hidden: int = HIDDEN_DIM,
54
+ out_channels: int = OUTPUT_DIM,
55
+ ) -> None:
56
+ super().__init__()
57
+ self.conv1 = GCNConv(in_channels, hidden) # 12 → 64
58
+ self.conv2 = GCNConv(hidden, hidden // 2) # 64 → 32
59
+ self.conv3 = GCNConv(hidden // 2, hidden // 4) # 32 → 16
60
+ self.fc = nn.Linear(hidden // 4, out_channels) # 16 → 1
61
+
62
+ def forward(
63
+ self,
64
+ x: torch.Tensor,
65
+ edge_index: torch.Tensor,
66
+ batch: Optional[torch.Tensor] = None,
67
+ ) -> torch.Tensor:
68
+ # Handle empty edge_index
69
+ if edge_index.numel() == 0:
70
+ edge_index = torch.zeros((2, 0), dtype=torch.long, device=x.device)
71
+
72
+ x = F.relu(self.conv1(x, edge_index))
73
+ x = F.relu(self.conv2(x, edge_index))
74
+ x = F.relu(self.conv3(x, edge_index))
75
+
76
+ if batch is None:
77
+ batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
78
+
79
+ x = global_mean_pool(x, batch) # (batch_size, 16)
80
+ x = self.fc(x) # (batch_size, 1)
81
+ return torch.sigmoid(x) # [0, 1]
82
+
83
+ def predict_proba(
84
+ self,
85
+ x: torch.Tensor,
86
+ edge_index: torch.Tensor,
87
+ batch: Optional[torch.Tensor] = None,
88
+ ) -> float:
89
+ """Return P_gnn ∈ [0,1] β€” probability of phishing."""
90
+ self.eval()
91
+ with torch.no_grad():
92
+ output = self.forward(x, edge_index, batch)
93
+ return output.squeeze().item()
94
+
95
+
96
+ # ── PhishMLP: Fallback for single URL or no torch_geometric ──────────
97
+ class PhishMLP(nn.Module):
98
+ """
99
+ MLP fallback for phishing classification.
100
+ Used when torch_geometric is unavailable or graph has < 2 nodes.
101
+ Architecture: Linear(12→64) → ReLU → Dropout(0.3) → Linear(64→1) → Sigmoid
102
+ """
103
+
104
+ def __init__(self, in_channels: int = INPUT_DIM) -> None:
105
+ super().__init__()
106
+ self.net = nn.Sequential(
107
+ nn.Linear(in_channels, 64),
108
+ nn.ReLU(),
109
+ nn.Dropout(0.3),
110
+ nn.Linear(64, 1),
111
+ )
112
+
113
+ def forward(
114
+ self,
115
+ x: torch.Tensor,
116
+ edge_index: Optional[torch.Tensor] = None,
117
+ batch: Optional[torch.Tensor] = None,
118
+ ) -> torch.Tensor:
119
+ # Pool all node features to single vector via mean
120
+ if x.dim() == 2 and x.size(0) > 1:
121
+ x = x.mean(dim=0, keepdim=True)
122
+ elif x.dim() == 1:
123
+ x = x.unsqueeze(0)
124
+ out = self.net(x)
125
+ return torch.sigmoid(out)
126
+
127
+ def predict_proba(
128
+ self,
129
+ x: torch.Tensor,
130
+ edge_index: Optional[torch.Tensor] = None,
131
+ batch: Optional[torch.Tensor] = None,
132
+ ) -> float:
133
+ """Return P_gnn ∈ [0,1] β€” probability of phishing."""
134
+ self.eval()
135
+ with torch.no_grad():
136
+ output = self.forward(x, edge_index, batch)
137
+ return output.squeeze().item()
138
+
139
+
140
+ # ── Model loading utility ────────────────────────────────────────────
141
+ def load_gnn_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
142
+ """
143
+ Load GNN or MLP model with optional trained weights.
144
+ Returns model in eval mode, or None if creation fails.
145
+ """
146
+ model: Optional[nn.Module] = None
147
+
148
+ try:
149
+ model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
150
+ except Exception as e:
151
+ logger.error(f"GNN model creation failed: {e}")
152
+ try:
153
+ model = PhishMLP()
154
+ except Exception as e2:
155
+ logger.error(f"MLP fallback creation also failed: {e2}")
156
+ return None
157
+
158
+ if model_path and os.path.exists(model_path):
159
+ try:
160
+ state = torch.load(model_path, map_location="cpu", weights_only=True)
161
+ model.load_state_dict(state)
162
+ logger.info(f"GNN weights loaded from {model_path}")
163
+ except RuntimeError as e:
164
+ logger.warning(f"GNN weights mismatch (architecture changed?): {e}")
165
+ except Exception as e:
166
+ logger.warning(f"GNN weight load failed: {e}")
167
+ elif model_path:
168
+ logger.info(f"GNN weights file not found: {model_path}")
169
+ else:
170
+ logger.info("No GNN weights path β€” using untrained model")
171
+
172
+ try:
173
+ model.eval()
174
+ except Exception as e:
175
+ logger.error(f"GNN eval() failed: {e}")
176
+ return None
177
+
178
+ return model
179
+
180
+
181
+ # Legacy alias
182
+ def load_model(model_path: Optional[str] = None) -> Optional[nn.Module]:
183
+ return load_gnn_model(model_path)
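A shape sanity check for the MLP fallback (a sketch; the features are random, so the output only matters as a range check):

    import torch
    from gnn_model import PhishMLP, INPUT_DIM

    model = PhishMLP()
    x = torch.randn(5, INPUT_DIM)    # 5 nodes, 12-dim features each
    p = model.predict_proba(x)       # mean-pools the nodes, returns a float
    assert 0.0 <= p <= 1.0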
keep_alive.py ADDED
@@ -0,0 +1,50 @@
1
+ # ============================================================
2
+ # PhishGuard AI - keep_alive.py
3
+ # Keeps your Render.com server awake 24/7.
4
+ #
5
+ # From architecture doc 5.3:
6
+ # Render free tier sleeps after 15 min of inactivity.
7
+ # This pings GET /health every 14 min to prevent that.
8
+ #
9
+ # WHERE TO RUN:
10
+ # - On a second laptop / old computer
11
+ # - On your phone using Termux (free Android app)
12
+ # - On a friend's computer
13
+ # - On your own laptop in a separate terminal (less ideal)
14
+ #
15
+ # HOW TO RUN:
16
+ # python keep_alive.py
17
+ # (keep this terminal window open β€” never close it)
18
+ # ============================================================
19
+
20
+ import time
21
+ import requests
22
+ import datetime
23
+
24
+ # !! CHANGE THIS to your actual Render URL !!
25
+ API_URL = "https://YOUR-APP-NAME.onrender.com/health"
26
+
27
+ INTERVAL = 14 * 60 # 14 minutes in seconds
28
+
29
+ print("=" * 50)
30
+ print("PhishGuard Keep-Alive Script")
31
+ print("=" * 50)
32
+ print(f"Pinging: {API_URL}")
33
+ print(f"Every: 14 minutes")
34
+ print(f"Started: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}")
35
+ print("\nDO NOT close this window!")
36
+ print("Press Ctrl+C to stop.\n")
37
+
38
+ ping_count = 0
39
+ while True:
40
+ try:
41
+ r = requests.get(API_URL, timeout=15)
42
+ ping_count += 1
43
+ status = "OK" if r.status_code == 200 else f"ERROR {r.status_code}"
44
+ print(f"[{datetime.datetime.now():%H:%M:%S}] Ping #{ping_count} β†’ {status}")
45
+ except requests.exceptions.ConnectionError:
46
+ print(f"[{datetime.datetime.now():%H:%M:%S}] Connection failed β€” server might be waking up...")
47
+ except Exception as e:
48
+ print(f"[{datetime.datetime.now():%H:%M:%S}] Error: {e}")
49
+
50
+ time.sleep(INTERVAL)
main.py ADDED
@@ -0,0 +1,699 @@
1
+ # ============================================================
2
+ # PhishGuard AI - main.py
3
+ # FastAPI orchestrator - Full 4-tier phishing detection pipeline
4
+ # with feedback-driven incremental retraining.
5
+ #
6
+ # Endpoints:
7
+ # POST /analyze → 4-tier URL phishing analysis
8
+ # POST /analyze/email → BERT-only email body analysis
9
+ # POST /retrain → Incremental model retraining
10
+ # GET /model_version → Current model version info
11
+ # GET /health → All model load statuses
12
+ #
13
+ # Architecture:
14
+ # Tier 1: Whitelist O(1) → SAFE exit (~55% traffic)
15
+ # Tier 2: Heuristic 15 signals → BLOCK if >= 80 (~15% blocked)
16
+ # Tier 3: BERT+GNN parallel → BLOCK/SAFE/escalate (~15% exits)
17
+ # Tier 4: CNN visual + brand hash → BLOCK/SAFE (~15% borderline)
18
+ # ============================================================
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ import sys
24
+ import asyncio
25
+ import time
26
+ import hashlib
27
+ import logging
28
+ import logging.handlers
29
+ from collections import OrderedDict
30
+ from contextlib import asynccontextmanager
31
+ from pathlib import Path
32
+ from typing import List, Optional
33
+
34
+ from fastapi import FastAPI
35
+ from fastapi.middleware.cors import CORSMiddleware
36
+ from pydantic import BaseModel
37
+
38
+ # ── Path setup ────────────────────────────────────────────────────────
39
+ BASE_DIR = Path(__file__).parent
40
+ for sub_dir in ["gnn", "cnn"]:
41
+ sub_path = BASE_DIR / sub_dir
42
+ if sub_path.is_dir():
43
+ sys.path.insert(0, str(sub_path))
44
+
45
+ # ── Logging ───────────────────────────────────────────────────────────
46
+ log_dir = BASE_DIR / "logs"
47
+ log_dir.mkdir(exist_ok=True)
48
+
49
+ _handler = logging.handlers.RotatingFileHandler(
50
+ log_dir / "phishguard.log",
51
+ maxBytes=5 * 1024 * 1024,
52
+ backupCount=3,
53
+ encoding="utf-8",
54
+ )
55
+ _handler.setFormatter(logging.Formatter(
56
+ "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
57
+ datefmt="%Y-%m-%d %H:%M:%S",
58
+ ))
59
+
60
+ logger = logging.getLogger("phishguard")
61
+ logger.setLevel(logging.INFO)
62
+ logger.addHandler(_handler)
63
+ logger.addHandler(logging.StreamHandler())
64
+
65
+
66
+ # ── Import project modules ───────────────────────────────────────────
67
+ from url_heuristics import HeuristicScorer, HeuristicResult
68
+ from bert_analyzer import BERTPhishingClassifier
69
+
70
+ # GNN imports
71
+ GNN_AVAILABLE = False
72
+ gnn_inference = None
73
+ try:
74
+ from gnn.gnn_inference import GNNInference
75
+ GNN_AVAILABLE = True
76
+ except ImportError:
77
+ try:
78
+ from gnn_inference import GNNInference
79
+ GNN_AVAILABLE = True
80
+ except ImportError:
81
+ logger.warning("GNN module not available")
82
+
83
+ # CNN imports
84
+ CNN_AVAILABLE = False
85
+ cnn_inference = None
86
+ brand_detector = None
87
+ try:
88
+ from cnn.cnn_inference import CNNInference
89
+ from cnn.screenshot_hasher import BrandHashDetector
90
+ from cnn.cnn_model import preprocess_screenshot
91
+ CNN_AVAILABLE = True
92
+ except ImportError:
93
+ try:
94
+ from cnn_inference import CNNInference
95
+ from screenshot_hasher import BrandHashDetector
96
+ from cnn_model import preprocess_screenshot
97
+ CNN_AVAILABLE = True
98
+ except ImportError:
99
+ logger.warning("CNN module not available")
100
+
101
+ from tier3_bert_gnn import Tier3Ensemble
102
+ from retraining_service import RetrainingService, FeedbackRecord, RetrainResult
103
+
104
+
105
+ # ── Whitelist (Tier 1) ────────────────────────────────────────────────
106
+ WHITELIST: set[str] = {
107
+ "google.com", "youtube.com", "facebook.com", "amazon.com", "wikipedia.org",
108
+ "twitter.com", "instagram.com", "linkedin.com", "microsoft.com", "apple.com",
109
+ "github.com", "stackoverflow.com", "reddit.com", "netflix.com", "paypal.com",
110
+ "bankofamerica.com", "chase.com", "wellsfargo.com", "yahoo.com", "bing.com",
111
+ "outlook.com", "office.com", "live.com", "adobe.com", "dropbox.com",
112
+ "zoom.us", "slack.com", "spotify.com", "twitch.tv", "ebay.com",
113
+ "walmart.com", "target.com", "bestbuy.com", "airbnb.com",
114
+ "x.com", "tiktok.com", "pinterest.com", "quora.com", "medium.com",
115
+ }
116
+
117
+
118
+ def get_root_domain(url: str) -> str:
119
+ """Extract root domain from a URL."""
120
+ from urllib.parse import urlparse
121
+ try:
122
+ host = urlparse(url).hostname or ""
123
+ host = host[4:] if host.startswith("www.") else host  # strip only a leading "www."
124
+ parts = host.split(".")
125
+ return ".".join(parts[-2:]) if len(parts) >= 2 else host
126
+ except Exception:
127
+ return ""
128
+
129
+
130
+ # ── URL Cache (LRU, 30-min TTL) ──────────────────────────────────────
131
+ CACHE_TTL = 30 * 60
132
+ CACHE_MAX = 500
133
+
134
+
135
+ class URLCache:
136
+ def __init__(self, maxsize: int = CACHE_MAX, ttl: int = CACHE_TTL) -> None:
137
+ self._cache: OrderedDict = OrderedDict()
138
+ self._maxsize = maxsize
139
+ self._ttl = ttl
140
+
141
+ def get(self, url: str) -> Optional[dict]:
142
+ if url in self._cache:
143
+ entry = self._cache[url]
144
+ if time.time() - entry["ts"] < self._ttl:
145
+ self._cache.move_to_end(url)
146
+ return entry["result"]
147
+ else:
148
+ del self._cache[url]
149
+ return None
150
+
151
+ def set(self, url: str, result: dict) -> None:
152
+ self._cache[url] = {"result": result, "ts": time.time()}
153
+ self._cache.move_to_end(url)
154
+ if len(self._cache) > self._maxsize:
155
+ self._cache.popitem(last=False)
156
+
157
+ def clear(self) -> None:
158
+ self._cache.clear()
159
+
160
+
161
+ _url_cache = URLCache()
162
+
163
+
164
+ # ── Request/Response Models ───────────────────────────────────────────
165
+ class AnalyzeRequest(BaseModel):
166
+ url: str
167
+ heuristic_score: float = 0.0
168
+ page_title: str = ""
169
+ page_snippet: str = ""
170
+ related_urls: list = []
171
+
172
+
173
+ class EmailRequest(BaseModel):
174
+ sender: str
175
+ subject: str = ""
176
+ body: str = ""
177
+ urls: list = []
178
+ timestamp: str = ""
179
+
180
+
181
+ class FeedbackSample(BaseModel):
182
+ url: str
183
+ verdict: str = ""
184
+ confidence: float = 0.0
185
+ tier_used: int = 0
186
+ heuristic_score: int = 0
187
+ signals: list = []
188
+ user_feedback: Optional[str] = None
189
+ timestamp: str = ""
190
+ feedback_ts: Optional[str] = None
191
+ url_hash: str = ""
192
+ session_id: str = ""
193
+
194
+
195
+ class RetrainRequest(BaseModel):
196
+ samples: List[FeedbackSample]
197
+ trigger: str = "count"
198
+ session_id: str = ""
199
+ extension_version: str = ""
200
+
201
+
202
+ # ── Global state ──────────────────────────────────────────────────────
203
+ _scorer: Optional[HeuristicScorer] = None
204
+ _bert: Optional[BERTPhishingClassifier] = None
205
+ _gnn: Optional[GNNInference] = None
206
+ _cnn: Optional[CNNInference] = None
207
+ _brand: Optional[BrandHashDetector] = None
208
+ _tier3: Optional[Tier3Ensemble] = None
209
+ _retrain_service: Optional[RetrainingService] = None
210
+ _retrain_lock = asyncio.Lock()
211
+
212
+
213
+ # ── Lifespan (startup/shutdown) ───────────────────────────────────────
214
+ @asynccontextmanager
215
+ async def lifespan(app: FastAPI):
216
+ """Load all models at startup, clean up at shutdown."""
217
+ global _scorer, _bert, _gnn, _cnn, _brand, _tier3, _retrain_service
218
+
219
+ logger.info("=== PhishGuard AI starting up ===")
220
+
221
+ # Tier 2: Heuristic Scorer
222
+ _scorer = HeuristicScorer()
223
+ logger.info("βœ“ Tier 2: HeuristicScorer initialized")
224
+
225
+ # Tier 3a: BERT
226
+ _bert = BERTPhishingClassifier()
227
+ logger.info("βœ“ Tier 3a: BERT classifier initialized (lazy-load)")
228
+
229
+ # Tier 3b: GNN
230
+ if GNN_AVAILABLE:
231
+ _gnn = GNNInference()
232
+ _gnn.load()
233
+ logger.info(f"βœ“ Tier 3b: GNN loaded={_gnn.is_loaded}")
234
+ else:
235
+ _gnn = None
236
+ logger.warning("βœ— Tier 3b: GNN not available")
237
+
238
+ # Tier 3 Ensemble
239
+ if _gnn:
240
+ _tier3 = Tier3Ensemble(_bert, _gnn)
241
+ logger.info("βœ“ Tier 3: Ensemble initialized")
242
+ else:
243
+ _tier3 = None
244
+ logger.warning("βœ— Tier 3: Ensemble not available (GNN missing)")
245
+
246
+ # Tier 4: CNN + Brand Detection
247
+ if CNN_AVAILABLE:
248
+ _cnn = CNNInference()
249
+ _cnn.load()
250
+ _brand = BrandHashDetector()
251
+ logger.info(f"βœ“ Tier 4: CNN loaded={_cnn.is_loaded}, Brand hash DB loaded")
252
+ else:
253
+ _cnn = None
254
+ _brand = None
255
+ logger.warning("βœ— Tier 4: CNN not available")
256
+
257
+ # Retraining Service
258
+ _retrain_service = RetrainingService(
259
+ bert_classifier=_bert,
260
+ gnn_inference=_gnn or GNNInference(),
261
+ cnn_inference=_cnn or (CNNInference() if CNN_AVAILABLE else None),
262
+ )
263
+ logger.info("βœ“ Retraining service initialized")
264
+ logger.info("=== PhishGuard AI ready ===")
265
+
266
+ yield
267
+
268
+ logger.info("=== PhishGuard AI shutting down ===")
269
+
270
+
271
+ # ── FastAPI App ───────────────────────────────────────────────────────
272
+ app = FastAPI(
273
+ title="PhishGuard AI Backend",
274
+ version="3.0",
275
+ description="4-tier ML phishing detection with feedback-driven retraining",
276
+ lifespan=lifespan,
277
+ )
278
+
279
+ app.add_middleware(
280
+ CORSMiddleware,
281
+ allow_origins=["*"],
282
+ allow_methods=["*"],
283
+ allow_headers=["*"],
284
+ )
285
+
286
+
287
+ # ── POST /analyze - Full 4-tier pipeline ──────────────────────────────
288
+ @app.post("/analyze")
289
+ async def analyze_endpoint(req: AnalyzeRequest) -> dict:
290
+ """
291
+ Analyze a URL through the 4-tier phishing detection pipeline.
292
+
293
+ Tier 1: Whitelist → SAFE
294
+ Tier 2: Heuristic → BLOCK if >= 80
295
+ Tier 3: BERT+GNN ensemble → BLOCK/SAFE/escalate
296
+ Tier 4: CNN visual + brand hash → BLOCK/SAFE
297
+ """
298
+ url = req.url
299
+ details: dict = {}
300
+
301
+ # ── TIER 1: Whitelist ────────────────────────────────────────
302
+ root = get_root_domain(url)
303
+ if root in WHITELIST:
304
+ return {
305
+ "url": url,
306
+ "is_phishing": False,
307
+ "confidence": 0.0,
308
+ "method": "whitelist",
309
+ "status": "safe",
310
+ "tier": 1,
311
+ "heuristic_score": 0,
312
+ "signals": [],
313
+ "details": {"whitelisted_domain": root},
314
+ }
315
+
316
+ # ── Cache check ──────────────────────────────────────────────
317
+ cached = _url_cache.get(url)
318
+ if cached is not None:
319
+ return cached
320
+
321
+ # ── TIER 2: Heuristic scoring ────────────────────────────────
322
+ h_result: HeuristicResult = _scorer.score(url)
323
+
324
+ # Use the higher of server-side and browser-side heuristic scores
325
+ h_score = max(h_result.score, int(req.heuristic_score))
326
+ details["heuristic"] = {
327
+ "score": h_result.score,
328
+ "raw_score": h_result.raw_score,
329
+ "signals": h_result.signals,
330
+ "browser_score": int(req.heuristic_score),
331
+ "combined_score": h_score,
332
+ }
333
+
334
+ if h_score >= 80:
335
+ result = {
336
+ "url": url,
337
+ "is_phishing": True,
338
+ "confidence": h_score / 100.0,
339
+ "method": "heuristic",
340
+ "status": "blocked",
341
+ "tier": 2,
342
+ "heuristic_score": h_score,
343
+ "signals": h_result.signals,
344
+ "details": details,
345
+ }
346
+ _url_cache.set(url, result)
347
+ logger.info(f"Tier 2 BLOCK | url={url[:60]} | score={h_score}")
348
+ return result
349
+
350
+ # ── TIER 3: BERT + GNN Ensemble ──────────────────────────────
351
+ if _tier3 is not None:
352
+ try:
353
+ p3 = await _tier3.predict(
354
+ url=url,
355
+ title=req.page_title,
356
+ snippet=req.page_snippet,
357
+ h_score=h_score,
358
+ )
359
+ details["tier3_score"] = p3
360
+ except Exception as e:
361
+ logger.error(f"Tier 3 error: {e}")
362
+ p3 = h_score / 100.0 # fallback to heuristic
363
+ details["tier3_error"] = str(e)
364
+ else:
365
+ # Tier 3 unavailable - use BERT alone + heuristic
366
+ if _bert is not None:
367
+ loop = asyncio.get_running_loop()
368
+ try:
369
+ p_bert = await loop.run_in_executor(
370
+ None, _bert.predict, url, req.page_title, req.page_snippet,
371
+ )
372
+ except Exception:
373
+ p_bert = 0.5
374
+ h_norm = h_score / 100.0
375
+ p3 = 0.60 * p_bert + 0.40 * h_norm
376
+ else:
377
+ p3 = h_score / 100.0
378
+ details["tier3_score"] = p3
379
+ details["tier3_note"] = "ensemble_unavailable"
380
+
381
+ # Tier 3 decision
382
+ decision = Tier3Ensemble.decide(p3)
383
+
384
+ if decision == "block":
385
+ result = {
386
+ "url": url,
387
+ "is_phishing": True,
388
+ "confidence": round(p3, 4),
389
+ "method": "bert_gnn_ensemble",
390
+ "status": "blocked",
391
+ "tier": 3,
392
+ "heuristic_score": h_score,
393
+ "signals": h_result.signals,
394
+ "details": details,
395
+ }
396
+ _url_cache.set(url, result)
397
+ logger.info(f"Tier 3 BLOCK | url={url[:60]} | P3={p3:.4f}")
398
+ return result
399
+
400
+ if decision == "safe":
401
+ result = {
402
+ "url": url,
403
+ "is_phishing": False,
404
+ "confidence": round(p3, 4),
405
+ "method": "bert_gnn_ensemble",
406
+ "status": "safe",
407
+ "tier": 3,
408
+ "heuristic_score": h_score,
409
+ "signals": h_result.signals,
410
+ "details": details,
411
+ }
412
+ _url_cache.set(url, result)
413
+ logger.info(f"Tier 3 SAFE | url={url[:60]} | P3={p3:.4f}")
414
+ return result
415
+
416
+ # ── TIER 4: CNN Visual + Brand Hash (borderline 0.40 <= P3 < 0.85)
417
+ if _cnn is not None and _cnn.is_loaded:
418
+ try:
419
+ # Capture screenshot
420
+ screenshot_bytes = await _capture_screenshot_for_tier4(url)
421
+
422
+ if screenshot_bytes:
423
+ # CNN prediction
424
+ p_cnn = _cnn.predict(screenshot_bytes)
425
+ details["cnn_prob"] = round(p_cnn, 4)
426
+
427
+ # Brand hash check
428
+ brand_boost = 0.0
429
+ if _brand is not None:
430
+ is_impersonation, brand_name, brand_conf = _brand.detect(
431
+ screenshot_bytes, url
432
+ )
433
+ details["brand"] = {
434
+ "impersonation_detected": is_impersonation,
435
+ "brand": brand_name,
436
+ "confidence": round(brand_conf, 3),
437
+ }
438
+ if is_impersonation:
439
+ brand_boost = 0.25
440
+
441
+ # P_final = 0.55·P3 + 0.30·P_cnn + brand_boost
442
+ p_final = min((0.55 * p3) + (0.30 * p_cnn) + brand_boost, 1.0)
443
+ details["tier4_score"] = round(p_final, 4)
444
+
445
+ is_phishing = p_final >= 0.65
446
+ result = {
447
+ "url": url,
448
+ "is_phishing": is_phishing,
449
+ "confidence": round(p_final, 4),
450
+ "method": "full_ensemble_bert_gnn_cnn",
451
+ "status": "blocked" if is_phishing else "safe",
452
+ "tier": 4,
453
+ "heuristic_score": h_score,
454
+ "signals": h_result.signals,
455
+ "details": details,
456
+ }
457
+ _url_cache.set(url, result)
458
+ logger.info(f"Tier 4 {'BLOCK' if is_phishing else 'SAFE'} | url={url[:60]} | P_final={p_final:.4f}")
459
+ return result
460
+
461
+ except Exception as e:
462
+ logger.error(f"Tier 4 error: {e}")
463
+ details["tier4_error"] = str(e)
464
+
465
+ # Tier 4 unavailable/failed - use Tier 3 score with conservative threshold
466
+ is_phishing = p3 >= 0.65
467
+ result = {
468
+ "url": url,
469
+ "is_phishing": is_phishing,
470
+ "confidence": round(p3, 4),
471
+ "method": "bert_gnn_ensemble",
472
+ "status": "blocked" if is_phishing else "safe",
473
+ "tier": 3,
474
+ "heuristic_score": h_score,
475
+ "signals": h_result.signals,
476
+ "details": details,
477
+ }
478
+ _url_cache.set(url, result)
479
+ logger.info(f"Tier 4 fallback β†’ Tier 3 | url={url[:60]} | P3={p3:.4f}")
480
+ return result
481
+
482
+
483
+ async def _capture_screenshot_for_tier4(url: str) -> Optional[bytes]:
484
+ """Capture screenshot for Tier 4 CNN analysis."""
485
+ try:
486
+ from playwright.async_api import async_playwright
487
+
488
+ async with async_playwright() as p:
489
+ browser = await p.chromium.launch(headless=True)
490
+ page = await browser.new_page(
491
+ viewport={"width": 1280, "height": 800},
492
+ user_agent=(
493
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
494
+ "AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36"
495
+ ),
496
+ )
497
+
498
+ # Block heavy resources
499
+ await page.route(
500
+ "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,wav,mp3}",
501
+ lambda route: route.abort(),
502
+ )
503
+
504
+ await page.goto(url, wait_until="domcontentloaded", timeout=10000)
505
+ screenshot = await page.screenshot(type="png")
506
+ await browser.close()
507
+ return screenshot
508
+
509
+ except Exception as e:
510
+ logger.warning(f"Tier 4 screenshot failed: {e}")
511
+ return None
512
+
513
+
514
+ # ── POST /analyze/email ───────────────────────────────────────────────
515
+ @app.post("/analyze/email")
516
+ async def analyze_email_endpoint(req: EmailRequest) -> dict:
517
+ """BERT-only path for email body text analysis."""
518
+ # Sender whitelist check
519
+ sender_domain = req.sender.split("@")[-1].lower() if "@" in req.sender else ""
520
+ if sender_domain in WHITELIST:
521
+ return {
522
+ "status": "safe",
523
+ "analysis": {
524
+ "isPhishing": False,
525
+ "probability": 0.0,
526
+ "reason": "Trusted sender domain",
527
+ },
528
+ }
529
+
530
+ # Analyze embedded URLs
531
+ MAX_URLS = 3
532
+ urls_to_check = req.urls[:MAX_URLS]
533
+
534
+ if not urls_to_check:
535
+ # Text-only analysis
536
+ if _bert:
537
+ combined = f"{req.subject} {req.body}"
538
+ prob = _bert.predict(combined, req.subject, req.body)
539
+ is_phishing = prob > 0.6
540
+ return {
541
+ "status": "blocked" if is_phishing else "safe",
542
+ "analysis": {
543
+ "isPhishing": is_phishing,
544
+ "probability": prob,
545
+ "reason": "BERT text analysis (no URLs)",
546
+ },
547
+ }
548
+ return {
549
+ "status": "safe",
550
+ "analysis": {
551
+ "isPhishing": False,
552
+ "probability": 0.1,
553
+ "reason": "No URLs and no ML model available",
554
+ },
555
+ }
556
+
557
+ # Analyze URLs through the main pipeline
558
+ tasks = [
559
+ analyze_endpoint(AnalyzeRequest(url=u, page_title=req.subject))
560
+ for u in urls_to_check
561
+ ]
562
+ results = await asyncio.gather(*tasks, return_exceptions=True)
563
+
564
+ max_prob = 0.0
565
+ phishing_detected = False
566
+ flagged_urls = []
567
+
568
+ for idx, r in enumerate(results):
569
+ if isinstance(r, Exception):
570
+ continue
571
+ prob = r.get("confidence", 0.0)
572
+ max_prob = max(max_prob, prob)
573
+ if r.get("is_phishing"):
574
+ phishing_detected = True
575
+ flagged_urls.append(r.get("url", urls_to_check[idx]))
576
+
577
+ return {
578
+ "status": "blocked" if phishing_detected else "safe",
579
+ "analysis": {
580
+ "isPhishing": phishing_detected,
581
+ "probability": max_prob,
582
+ "flagged_urls": flagged_urls,
583
+ "reason": "URL analysis via ML ensemble",
584
+ },
585
+ }
586
+
587
+
588
+ # ── POST /retrain - Incremental retraining ────────────────────────────
589
+ @app.post("/retrain")
590
+ async def retrain_endpoint(req: RetrainRequest) -> dict:
591
+ """
592
+ Receive labeled feedback and incrementally update all models.
593
+ Uses asyncio.Lock() to prevent concurrent retraining jobs.
594
+ Timeout: 600s max.
595
+ """
596
+ if _retrain_service is None:
597
+ return {"status": "error", "message": "Retraining service not initialized"}
598
+
599
+ # Prevent concurrent retraining
600
+ if _retrain_lock.locked():
601
+ return {
602
+ "status": "skipped",
603
+ "message": "Retraining already in progress",
604
+ "models_updated": [],
605
+ }
606
+
607
+ async with _retrain_lock:
608
+ # Convert Pydantic models to FeedbackRecord dataclasses
609
+ records = [
610
+ FeedbackRecord(
611
+ url=s.url,
612
+ verdict=s.verdict,
613
+ confidence=s.confidence,
614
+ tier_used=s.tier_used,
615
+ heuristic_score=s.heuristic_score,
616
+ signals=s.signals,
617
+ user_feedback=s.user_feedback,
618
+ timestamp=s.timestamp,
619
+ feedback_ts=s.feedback_ts,
620
+ url_hash=s.url_hash,
621
+ session_id=s.session_id,
622
+ )
623
+ for s in req.samples
624
+ ]
625
+
626
+ try:
627
+ result = await asyncio.wait_for(
628
+ _retrain_service.retrain(records),
629
+ timeout=600,
630
+ )
631
+
632
+ # Clear URL cache after retraining (stale results)
633
+ if result.status == "success":
634
+ _url_cache.clear()
635
+
636
+ return {
637
+ "status": result.status,
638
+ "models_updated": result.models_updated,
639
+ "samples_used": result.samples_used,
640
+ "duration_seconds": result.duration_seconds,
641
+ "accuracy_delta": result.accuracy_delta,
642
+ "next_retrain_hint": result.next_retrain_hint,
643
+ }
644
+
645
+ except asyncio.TimeoutError:
646
+ return {
647
+ "status": "error",
648
+ "message": "Retraining timed out (600s limit)",
649
+ }
650
+ except Exception as e:
651
+ logger.error(f"Retrain endpoint error: {e}")
652
+ return {
653
+ "status": "error",
654
+ "message": str(e),
655
+ }
656
+
657
+
658
+ # ── GET /model_version ────────────────────────────────────────────────
659
+ @app.get("/model_version")
660
+ async def model_version_endpoint() -> dict:
661
+ """Return current model version info for extension polling."""
662
+ if _retrain_service:
663
+ return _retrain_service.get_version_info()
664
+ return {"version": 0, "updated_at": None, "accuracy": {}}
665
+
666
+
667
+ # ── GET /health ───────────────────────────────────────────────────────
668
+ @app.get("/health")
669
+ async def health_endpoint() -> dict:
670
+ """Liveness probe with per-tier readiness and model statuses."""
671
+ return {
672
+ "status": "ok",
673
+ "version": "3.0",
674
+ "tier1": True,
675
+ "tier2": _scorer is not None,
676
+ "tier3": _tier3 is not None,
677
+ "tier4": _cnn is not None and _cnn.is_loaded if _cnn else False,
678
+ "retraining_in_progress": _retrain_lock.locked(),
679
+ "model_version": _retrain_service.model_version if _retrain_service else 0,
680
+ "modules": {
681
+ "heuristic": _scorer is not None,
682
+ "bert": _bert is not None and _bert.is_loaded,
683
+ "bert_lazy": _bert is not None and not _bert.is_loaded,
684
+ "gnn": _gnn is not None and _gnn.is_loaded if _gnn else False,
685
+ "cnn": _cnn is not None and _cnn.is_loaded if _cnn else False,
686
+ "brand_hash": _brand is not None,
687
+ },
688
+ }
689
+
690
+
691
+ # ── Legacy feedback endpoint (backward compat) ───────────────────────
692
+ @app.post("/feedback")
693
+ async def legacy_feedback_endpoint(req: dict) -> dict:
694
+ """Legacy feedback endpoint for backward compatibility."""
695
+ return {"status": "success", "message": "Use POST /retrain for feedback-driven retraining"}
696
+
697
+
698
+ # ── Run directly ──────────────────────────────────────────────────────
699
+ # uvicorn main:app --reload --port 8000
manifest.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "manifest_version": 3,
3
+ "name": "PhishGuard AI",
4
+ "version": "3.0",
5
+ "description": "Adaptive ML-based phishing detection with feedback-driven retraining β€” BERT + GNN + CNN ensemble",
6
+ "permissions": [
7
+ "tabs",
8
+ "storage",
9
+ "webNavigation",
10
+ "alarms",
11
+ "notifications",
12
+ "activeTab",
13
+ "scripting"
14
+ ],
15
+ "host_permissions": [
16
+ "<all_urls>"
17
+ ],
18
+ "background": {
19
+ "service_worker": "background.js"
20
+ },
21
+ "content_scripts": [
22
+ {
23
+ "matches": ["<all_urls>"],
24
+ "js": ["content.js"],
25
+ "run_at": "document_idle"
26
+ }
27
+ ],
28
+ "action": {
29
+ "default_popup": "popup.html",
30
+ "default_title": "PhishGuard AI"
31
+ },
32
+ "icons": {
33
+ "16": "icons/icon16.png",
34
+ "48": "icons/icon48.png",
35
+ "128": "icons/icon128.png"
36
+ }
37
+ }
popup.html ADDED
@@ -0,0 +1,432 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>PhishGuard AI</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
9
+
10
+ * { margin: 0; padding: 0; box-sizing: border-box; }
11
+
12
+ :root {
13
+ --bg-primary: #0F0F14;
14
+ --bg-secondary: #1A1A24;
15
+ --bg-card: #22222E;
16
+ --bg-hover: #2A2A38;
17
+ --text-primary: #EAEAF0;
18
+ --text-secondary: #8888A0;
19
+ --text-muted: #5A5A72;
20
+ --accent: #534AB7;
21
+ --accent-glow: rgba(83, 74, 183, 0.35);
22
+ --safe: #22C55E;
23
+ --safe-glow: rgba(34, 197, 94, 0.25);
24
+ --danger: #EF4444;
25
+ --danger-glow: rgba(239, 68, 68, 0.25);
26
+ --warning: #F59E0B;
27
+ --warning-glow: rgba(245, 158, 11, 0.25);
28
+ --border: rgba(255,255,255,0.06);
29
+ --radius: 12px;
30
+ --radius-sm: 8px;
31
+ }
32
+
33
+ body {
34
+ width: 380px;
35
+ min-height: 480px;
36
+ max-height: 640px;
37
+ font-family: 'Inter', -apple-system, sans-serif;
38
+ background: var(--bg-primary);
39
+ color: var(--text-primary);
40
+ overflow-y: auto;
41
+ scrollbar-width: thin;
42
+ scrollbar-color: var(--bg-hover) transparent;
43
+ }
44
+ body::-webkit-scrollbar { width: 4px; }
45
+ body::-webkit-scrollbar-thumb { background: var(--bg-hover); border-radius: 4px; }
46
+
47
+ /* ── Header ──────────────────────────────────────────── */
48
+ .header {
49
+ display: flex; align-items: center; gap: 10px;
50
+ padding: 14px 20px 10px;
51
+ border-bottom: 1px solid var(--border);
52
+ }
53
+ .header-logo {
54
+ width: 28px; height: 28px;
55
+ background: linear-gradient(135deg, var(--accent), #7C6BDB);
56
+ border-radius: var(--radius-sm);
57
+ display: flex; align-items: center; justify-content: center;
58
+ font-size: 14px;
59
+ }
60
+ .header h1 { font-size: 15px; font-weight: 700; letter-spacing: -0.3px; }
61
+ .header h1 span { color: var(--accent); }
62
+ .header-badge {
63
+ margin-left: auto;
64
+ font-size: 10px; padding: 3px 8px;
65
+ background: var(--bg-card); border: 1px solid var(--border);
66
+ border-radius: 20px; color: var(--text-secondary); font-weight: 500;
67
+ }
68
+
69
+ /* ── URL Bar ──────────────────────────────────────────── */
70
+ .url-bar {
71
+ padding: 8px 20px;
72
+ background: var(--bg-secondary);
73
+ border-bottom: 1px solid var(--border);
74
+ }
75
+ .url-text {
76
+ font-size: 11px; color: var(--text-muted);
77
+ overflow: hidden; text-overflow: ellipsis; white-space: nowrap;
78
+ font-family: 'SF Mono', 'Fira Code', monospace;
79
+ }
80
+
81
+ /* ── Loading ──────────────────────────────────────────── */
82
+ .loading-container {
83
+ display: flex; flex-direction: column; align-items: center;
84
+ justify-content: center; padding: 40px 20px; gap: 14px;
85
+ }
86
+ .spinner {
87
+ width: 40px; height: 40px;
88
+ border: 3px solid var(--bg-hover);
89
+ border-top-color: var(--accent);
90
+ border-radius: 50%;
91
+ animation: spin 0.8s linear infinite;
92
+ }
93
+ @keyframes spin { to { transform: rotate(360deg); } }
94
+ .loading-text {
95
+ font-size: 13px; color: var(--text-secondary);
96
+ animation: pulse 1.5s ease-in-out infinite;
97
+ }
98
+ @keyframes pulse { 0%,100% { opacity: 1; } 50% { opacity: 0.5; } }
99
+
100
+ /* ── Result Panel ────────────────────────────────────── */
101
+ .result-panel { padding: 16px 20px; }
102
+
103
+ .result-hero {
104
+ display: flex; align-items: center; gap: 16px;
105
+ margin-bottom: 16px;
106
+ }
107
+
108
+ .score-ring-wrap {
109
+ position: relative; width: 80px; height: 80px; flex-shrink: 0;
110
+ }
111
+ .score-ring-bg, .score-ring-fg {
112
+ fill: none; stroke-width: 6;
113
+ }
114
+ .score-ring-bg { stroke: var(--bg-hover); }
115
+ .score-ring-fg {
116
+ stroke-linecap: round;
117
+ transform: rotate(-90deg); transform-origin: center;
118
+ transition: stroke-dashoffset 1s ease, stroke 0.5s;
119
+ stroke-dasharray: 213; stroke-dashoffset: 213;
120
+ }
121
+ .score-label {
122
+ position: absolute; inset: 0;
123
+ display: flex; flex-direction: column;
124
+ align-items: center; justify-content: center;
125
+ }
126
+ .score-pct { font-size: 20px; font-weight: 700; line-height: 1; }
127
+ .score-sub {
128
+ font-size: 9px; color: var(--text-muted);
129
+ margin-top: 2px; text-transform: uppercase; letter-spacing: 0.5px;
130
+ }
131
+
132
+ .shield-icon {
133
+ font-size: 28px;
134
+ animation: shieldPop 0.6s cubic-bezier(0.34, 1.56, 0.64, 1);
135
+ }
136
+ @keyframes shieldPop {
137
+ 0% { transform: scale(0.3) rotate(-15deg); opacity: 0; }
138
+ 60% { transform: scale(1.15) rotate(3deg); }
139
+ 100% { transform: scale(1) rotate(0); opacity: 1; }
140
+ }
141
+
142
+ .result-verdict { flex: 1; }
143
+ .verdict-label { font-size: 16px; font-weight: 700; line-height: 1.2; }
144
+ .verdict-detail { font-size: 11px; color: var(--text-secondary); margin-top: 3px; }
145
+
146
+ .status-safe { color: var(--safe); }
147
+ .status-danger { color: var(--danger); }
148
+ .status-warn { color: var(--warning); }
149
+
150
+ /* ── Tier Rows ───────────────────────────────────────── */
151
+ .tier-section { margin-top: 4px; }
152
+ .tier-row {
153
+ background: var(--bg-card);
154
+ border: 1px solid var(--border);
155
+ border-radius: var(--radius-sm);
156
+ margin-bottom: 5px; overflow: hidden;
157
+ transition: border-color 0.2s;
158
+ }
159
+ .tier-row:hover { border-color: rgba(255,255,255,0.1); }
160
+ .tier-header {
161
+ display: flex; align-items: center;
162
+ padding: 8px 12px; cursor: pointer;
163
+ user-select: none; gap: 8px;
164
+ }
165
+ .tier-dot { width: 7px; height: 7px; border-radius: 50%; flex-shrink: 0; }
166
+ .tier-name { font-size: 11px; font-weight: 600; flex: 1; }
167
+ .tier-score {
168
+ font-size: 11px; font-weight: 600;
169
+ font-family: 'SF Mono', 'Fira Code', monospace;
170
+ }
171
+ .tier-chevron {
172
+ font-size: 9px; color: var(--text-muted);
173
+ transition: transform 0.2s;
174
+ }
175
+ .tier-row.open .tier-chevron { transform: rotate(180deg); }
176
+ .tier-body {
177
+ max-height: 0; overflow: hidden;
178
+ transition: max-height 0.3s ease; padding: 0 12px;
179
+ }
180
+ .tier-row.open .tier-body { max-height: 200px; padding: 4px 12px 10px; }
181
+ .tier-detail { font-size: 10px; color: var(--text-secondary); line-height: 1.6; }
182
+ .flag-badge {
183
+ display: inline-block; padding: 1px 6px;
184
+ background: rgba(239,68,68,0.12); color: var(--danger);
185
+ border-radius: 4px; font-size: 10px; margin: 2px 2px 2px 0;
186
+ }
187
+
188
+ /* ── Feedback Section ────────────────────────────────── */
189
+ .feedback-section {
190
+ padding: 0 20px 12px; margin-top: 8px;
191
+ }
192
+ .feedback-prompt {
193
+ font-size: 12px; color: var(--text-secondary);
194
+ margin-bottom: 8px; text-align: center;
195
+ }
196
+ .feedback-buttons { display: flex; gap: 8px; }
197
+ .fb-btn {
198
+ flex: 1; padding: 8px 0;
199
+ border: 1px solid var(--border); border-radius: var(--radius-sm);
200
+ background: var(--bg-card); color: var(--text-primary);
201
+ font-size: 13px; font-weight: 600; cursor: pointer;
202
+ transition: all 0.2s; font-family: inherit;
203
+ }
204
+ .fb-btn:hover { background: var(--bg-hover); }
205
+ .fb-btn-correct:hover { border-color: var(--safe); background: rgba(34,197,94,0.08); }
206
+ .fb-btn-wrong:hover { border-color: var(--danger); background: rgba(239,68,68,0.08); }
207
+ .fb-btn.selected {
208
+ opacity: 1 !important;
209
+ }
210
+ .fb-btn.dimmed {
211
+ opacity: 0.3; pointer-events: none;
212
+ }
213
+ .fb-btn-correct.selected {
214
+ border-color: var(--safe); background: rgba(34,197,94,0.15);
215
+ color: var(--safe);
216
+ }
217
+ .fb-btn-wrong.selected {
218
+ border-color: var(--danger); background: rgba(239,68,68,0.15);
219
+ color: var(--danger);
220
+ }
221
+
222
+ .thank-you {
223
+ display: none; text-align: center; padding: 10px;
224
+ font-size: 12px; color: var(--safe);
225
+ animation: slideDown 0.3s ease;
226
+ }
227
+ .thank-you.show { display: block; }
228
+ @keyframes slideDown {
229
+ from { opacity: 0; transform: translateY(-6px); }
230
+ to { opacity: 1; transform: translateY(0); }
231
+ }
232
+
233
+ /* ── Retraining Status ───────────────────────────────── */
234
+ .retrain-section {
235
+ padding: 8px 20px 12px;
236
+ border-top: 1px solid var(--border);
237
+ margin-top: 4px;
238
+ }
239
+ .retrain-row {
240
+ display: flex; align-items: center; gap: 6px;
241
+ font-size: 11px; color: var(--text-muted);
242
+ margin-bottom: 4px;
243
+ }
244
+ .retrain-row .icon { font-size: 12px; }
245
+ .retrain-progress {
246
+ height: 3px; background: var(--bg-hover);
247
+ border-radius: 2px; margin: 6px 0 4px;
248
+ overflow: hidden;
249
+ }
250
+ .retrain-progress-bar {
251
+ height: 100%; background: linear-gradient(90deg, var(--accent), #7C6BDB);
252
+ border-radius: 2px; transition: width 0.5s ease;
253
+ }
254
+
255
+ /* ── Session Stats ───────────────────────────────────── */
256
+ .stats-row {
257
+ display: flex; justify-content: space-between;
258
+ padding: 6px 20px;
259
+ border-top: 1px solid var(--border);
260
+ font-size: 11px; color: var(--text-muted);
261
+ }
262
+
263
+ /* ── Blocked Overlay ─────────────────────────────────── */
264
+ .blocked-overlay {
265
+ display: none; padding: 28px 24px; text-align: center;
266
+ }
267
+ .blocked-overlay.show {
268
+ display: flex; flex-direction: column;
269
+ align-items: center; gap: 10px;
270
+ }
271
+ .blocked-shield { font-size: 48px; animation: shieldPop 0.6s ease; }
272
+ .blocked-title { font-size: 18px; font-weight: 700; color: var(--danger); }
273
+ .blocked-url {
274
+ font-size: 11px; color: var(--text-muted);
275
+ word-break: break-all; max-width: 300px;
276
+ }
277
+ .blocked-method { font-size: 12px; color: var(--text-secondary); }
278
+ .proceed-btn {
279
+ margin-top: 8px; padding: 8px 20px;
280
+ background: transparent; border: 1px solid rgba(239,68,68,0.3);
281
+ border-radius: var(--radius-sm); color: var(--text-secondary);
282
+ font-size: 12px; cursor: pointer; font-family: inherit;
283
+ transition: all 0.2s;
284
+ }
285
+ .proceed-btn:hover {
286
+ background: rgba(239,68,68,0.08);
287
+ border-color: var(--danger); color: var(--danger);
288
+ }
289
+
290
+ .offline-banner {
291
+ display: none; padding: 8px 16px;
292
+ background: rgba(245, 158, 11, 0.08);
293
+ border: 1px solid rgba(245, 158, 11, 0.2);
294
+ border-radius: var(--radius-sm);
295
+ margin: 8px 20px 0; font-size: 11px;
296
+ color: var(--warning); text-align: center;
297
+ }
298
+ .offline-banner.show { display: block; }
299
+ </style>
300
+ </head>
301
+ <body>
302
+
303
+ <!-- Header -->
304
+ <div class="header">
305
+ <div class="header-logo">πŸ›‘οΈ</div>
306
+ <h1>Phish<span>Guard</span> AI</h1>
307
+ <span class="header-badge" id="versionBadge">v3.0</span>
308
+ </div>
309
+
310
+ <!-- URL Bar -->
311
+ <div class="url-bar">
312
+ <div class="url-text" id="currentUrl">Analyzing...</div>
313
+ </div>
314
+
315
+ <!-- Server offline banner -->
316
+ <div class="offline-banner" id="offlineBanner">
317
+ ⚠️ Server offline β€” local heuristic only
318
+ </div>
319
+
320
+ <!-- Loading State -->
321
+ <div class="loading-container" id="loadingState">
322
+ <div class="spinner"></div>
323
+ <div class="loading-text">Analyzing with AI ensemble...</div>
324
+ </div>
325
+
326
+ <!-- Result Panel -->
327
+ <div class="result-panel" id="resultPanel" style="display:none;">
328
+ <div class="result-hero">
329
+ <div class="score-ring-wrap">
330
+ <svg width="80" height="80" viewBox="0 0 80 80">
331
+ <circle class="score-ring-bg" cx="40" cy="40" r="34" />
332
+ <circle class="score-ring-fg" id="scoreRing" cx="40" cy="40" r="34" />
333
+ </svg>
334
+ <div class="score-label">
335
+ <div class="score-pct" id="scorePct">0%</div>
336
+ <div class="score-sub" id="scoreSub">RISK</div>
337
+ </div>
338
+ </div>
339
+ <div style="display:flex; flex-direction:column; align-items:center; gap:4px">
340
+ <div class="shield-icon" id="shieldIcon">πŸ›‘οΈ</div>
341
+ </div>
342
+ <div class="result-verdict">
343
+ <div class="verdict-label" id="verdictLabel">Analyzing</div>
344
+ <div class="verdict-detail" id="verdictDetail">Please wait...</div>
345
+ </div>
346
+ </div>
347
+
348
+ <!-- Tier Rows -->
349
+ <div class="tier-section" id="tierSection">
350
+ <div class="tier-row" data-tier="1">
351
+ <div class="tier-header" onclick="toggleTier(this)">
352
+ <div class="tier-dot" id="t1Dot"></div>
353
+ <div class="tier-name">Tier 1 Β· Whitelist</div>
354
+ <div class="tier-score" id="t1Score">β€”</div>
355
+ <div class="tier-chevron">β–Ό</div>
356
+ </div>
357
+ <div class="tier-body"><div class="tier-detail" id="t1Detail">O(1) domain lookup</div></div>
358
+ </div>
359
+ <div class="tier-row" data-tier="2">
360
+ <div class="tier-header" onclick="toggleTier(this)">
361
+ <div class="tier-dot" id="t2Dot"></div>
362
+ <div class="tier-name">Tier 2 Β· Heuristics</div>
363
+ <div class="tier-score" id="t2Score">β€”</div>
364
+ <div class="tier-chevron">β–Ό</div>
365
+ </div>
366
+ <div class="tier-body"><div class="tier-detail" id="t2Detail">15 regex/math signals</div></div>
367
+ </div>
368
+ <div class="tier-row" data-tier="3">
369
+ <div class="tier-header" onclick="toggleTier(this)">
370
+ <div class="tier-dot" id="t3Dot"></div>
371
+ <div class="tier-name">Tier 3 Β· BERT + GNN</div>
372
+ <div class="tier-score" id="t3Score">β€”</div>
373
+ <div class="tier-chevron">β–Ό</div>
374
+ </div>
375
+ <div class="tier-body"><div class="tier-detail" id="t3Detail">Parallel NLP + graph analysis</div></div>
376
+ </div>
377
+ <div class="tier-row" data-tier="4">
378
+ <div class="tier-header" onclick="toggleTier(this)">
379
+ <div class="tier-dot" id="t4Dot"></div>
380
+ <div class="tier-name">Tier 4 Β· CNN Visual</div>
381
+ <div class="tier-score" id="t4Score">β€”</div>
382
+ <div class="tier-chevron">β–Ό</div>
383
+ </div>
384
+ <div class="tier-body"><div class="tier-detail" id="t4Detail">Screenshot + brand detection</div></div>
385
+ </div>
386
+ </div>
387
+ </div>
388
+
389
+ <!-- Feedback Section -->
390
+ <div class="feedback-section" id="feedbackSection" style="display:none;">
391
+ <div class="feedback-prompt">Was this correct?</div>
392
+ <div class="feedback-buttons">
393
+ <button class="fb-btn fb-btn-correct" id="btnCorrect">πŸ‘ Correct</button>
394
+ <button class="fb-btn fb-btn-wrong" id="btnWrong">οΏ½οΏ½οΏ½ Incorrect</button>
395
+ </div>
396
+ <div class="thank-you" id="thankYou">βœ“ Thanks! Helps us improve 🎯</div>
397
+ </div>
398
+
399
+ <!-- Retraining Status -->
400
+ <div class="retrain-section" id="retrainSection" style="display:none;">
401
+ <div class="retrain-row">
402
+ <span class="icon">πŸ”„</span>
403
+ <span id="retrainStatus">Next retrain: calculating...</span>
404
+ </div>
405
+ <div class="retrain-progress">
406
+ <div class="retrain-progress-bar" id="retrainProgressBar" style="width: 0%"></div>
407
+ </div>
408
+ <div class="retrain-row">
409
+ <span class="icon">πŸ“ˆ</span>
410
+ <span id="retrainLast">No retraining yet</span>
411
+ </div>
412
+ </div>
413
+
414
+ <!-- Session Stats -->
415
+ <div class="stats-row" id="statsRow" style="display:none;">
416
+ <span id="statScanned">πŸ“Š 0 scanned</span>
417
+ <span id="statFeedback">πŸ’¬ 0 feedback</span>
418
+ <span id="statVersion">🏷️ v0</span>
419
+ </div>
420
+
421
+ <!-- Blocked Page Overlay -->
422
+ <div class="blocked-overlay" id="blockedOverlay">
423
+ <div class="blocked-shield">🚨</div>
424
+ <div class="blocked-title">Phishing Detected!</div>
425
+ <div class="blocked-url" id="blockedUrl"></div>
426
+ <div class="blocked-method" id="blockedMethod"></div>
427
+ <button class="proceed-btn" id="proceedBtn">Proceed Anyway (Unsafe)</button>
428
+ </div>
429
+
430
+ <script src="popup.js"></script>
431
+ </body>
432
+ </html>
popup.js ADDED
@@ -0,0 +1,332 @@
1
+ // ============================================================
2
+ // PhishGuard AI - popup.js
3
+ // Popup logic: displays verdict, feedback buttons, retraining
4
+ // status, and session stats.
5
+ // ============================================================
6
+
7
+ (function() {
8
+ "use strict";
9
+
10
+ // ── DOM Elements ──────────────────────────────────────────────
11
+ const $id = id => document.getElementById(id);
12
+ const loadingState = $id("loadingState");
13
+ const resultPanel = $id("resultPanel");
14
+ const feedbackSection = $id("feedbackSection");
15
+ const retrainSection = $id("retrainSection");
16
+ const statsRow = $id("statsRow");
17
+ const blockedOverlay = $id("blockedOverlay");
18
+ const offlineBanner = $id("offlineBanner");
19
+
20
+ let currentResult = null;
21
+ let currentUrlHash = null;
22
+ let feedbackGiven = false;
23
+ let feedbackTimeout = null;
24
+ let countdownInterval = null;
25
+
26
+ // ── Init ──────────────────────────────────────────────────────
27
+ async function init() {
28
+ // Check if this is a blocked page redirect
29
+ const params = new URLSearchParams(window.location.search);
30
+ if (params.get("blocked") === "1") {
31
+ showBlockedPage(params);
32
+ return;
33
+ }
34
+
35
+ // Get active tab
36
+ const [tab] = await chrome.tabs.query({ active: true, currentWindow: true });
37
+ if (!tab?.url || !tab.url.startsWith("http")) {
38
+ showResult({
39
+ status: "safe", tier: 0, method: "internal",
40
+ confidence: 0, url: tab?.url || "N/A"
41
+ });
42
+ return;
43
+ }
44
+
45
+ $id("currentUrl").textContent = tab.url;
46
+
47
+ // Try per-tab cache first (instant)
48
+ chrome.runtime.sendMessage(
49
+ { type: "get_tab_result", tabId: tab.id },
50
+ response => {
51
+ if (response?.result) {
52
+ showResult(response.result);
53
+ } else {
54
+ // Fallback to chrome.storage
55
+ chrome.storage.local.get("lastResult", data => {
56
+ if (data.lastResult && data.lastResult.url === tab.url) {
57
+ showResult(data.lastResult);
58
+ } else {
59
+ loadingState.style.display = "flex";
60
+ }
61
+ });
62
+ }
63
+ }
64
+ );
65
+
66
+ // Load status
67
+ loadStatus();
68
+ }
69
+
70
+ // ── Show Result ───────────────────────────────────────────────
71
+ function showResult(result) {
72
+ currentResult = result;
73
+ loadingState.style.display = "none";
74
+ resultPanel.style.display = "block";
75
+ feedbackSection.style.display = "block";
76
+ retrainSection.style.display = "block";
77
+ statsRow.style.display = "flex";
78
+
79
+ const isBlocked = result.status === "blocked" || result.is_phishing;
80
+ const isWarn = !isBlocked && (result.confidence || 0) >= 0.4;
81
+ const confidence = Math.round((result.confidence || 0) * 100);
82
+ const tier = result.tier || 0;
83
+
84
+ // Score ring
85
+ const ring = $id("scoreRing");
86
+ const circumference = 2 * Math.PI * 34; // r=34
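+ // 2Ο€Β·34 β‰ˆ 213.6, which is why the CSS seeds stroke-dasharray with 213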
87
+ const offset = circumference - (confidence / 100) * circumference;
88
+ ring.style.strokeDasharray = circumference;
89
+
90
+ setTimeout(() => {
91
+ ring.style.strokeDashoffset = offset;
92
+ ring.style.stroke = isBlocked ? "var(--danger)" :
93
+ isWarn ? "var(--warning)" : "var(--safe)";
94
+ }, 100);
95
+
96
+ $id("scorePct").textContent = confidence + "%";
97
+ $id("scorePct").className = `score-pct ${isBlocked ? "status-danger" : isWarn ? "status-warn" : "status-safe"}`;
98
+ $id("scoreSub").textContent = isBlocked ? "THREAT" : "RISK";
99
+
100
+ // Shield
101
+ $id("shieldIcon").textContent = isBlocked ? "🚨" : isWarn ? "⚠️" : "βœ…";
102
+
103
+ // Verdict text
104
+ $id("verdictLabel").textContent = isBlocked ? "PHISHING DETECTED" :
105
+ isWarn ? "SUSPICIOUS" : "SAFE";
106
+ $id("verdictLabel").className = `verdict-label ${isBlocked ? "status-danger" : isWarn ? "status-warn" : "status-safe"}`;
107
+
108
+ const methodNames = {
109
+ "whitelist": "Whitelist (Tier 1)",
110
+ "heuristic": "Heuristic Engine (Tier 2)",
111
+ "heuristic-fallback": "Heuristic Fallback",
112
+ "bert_gnn_ensemble": "BERT + GNN Ensemble (Tier 3)",
113
+ "full_ensemble_bert_gnn_cnn": "Full ML Ensemble (Tier 4)",
114
+ "ensemble_with_visual": "Ensemble + Visual (Tier 4)",
115
+ "user-override": "User Override",
116
+ };
117
+ const methodText = methodNames[result.method] || result.method || "Unknown";
118
+ $id("verdictDetail").textContent = `${methodText} Β· Confidence: ${confidence}%`;
119
+
120
+ // Tier dots and scores
121
+ updateTierRow(1, tier >= 1 ? "checked" : "pending", tier === 1 ? "SAFE βœ“" : "Miss β†’");
122
+ updateTierRow(2, tier >= 2 ? (isBlocked && tier === 2 ? "blocked" : "checked") : "pending",
123
+ result.heuristic_score != null ? `${result.heuristic_score}/100` : "β€”");
124
+ updateTierRow(3, tier >= 3 ? (isBlocked && tier === 3 ? "blocked" : "checked") : "pending",
125
+ result.details?.tier3_score != null ? (result.details.tier3_score * 100).toFixed(0) + "%" : "β€”");
126
+ updateTierRow(4, tier >= 4 ? (isBlocked && tier === 4 ? "blocked" : "checked") : "pending",
127
+ result.details?.tier4_score != null ? (result.details.tier4_score * 100).toFixed(0) + "%" : "β€”");
128
+
129
+ // Tier 2 details β€” show triggered signals
130
+ if (result.signals && result.signals.length > 0) {
131
+ const badges = result.signals.map(s => `<span class="flag-badge">${s}</span>`).join(" ");
132
+ $id("t2Detail").innerHTML = `Signals triggered:<br>${badges}`;
133
+ }
134
+
135
+ // Compute URL hash for feedback
136
+ computeUrlHash(result.url);
137
+
138
+ // Check if we already gave feedback
139
+ checkExistingFeedback();
140
+ }
141
+
142
+ function updateTierRow(tier, status, scoreText) {
143
+ const dot = $id(`t${tier}Dot`);
144
+ const score = $id(`t${tier}Score`);
145
+
146
+ const colors = {
147
+ checked: "var(--safe)",
148
+ blocked: "var(--danger)",
149
+ pending: "var(--text-muted)",
150
+ };
151
+ dot.style.background = colors[status] || colors.pending;
152
+ score.textContent = scoreText;
153
+ }
154
+
155
+ // ── Blocked Page ──────────────────────────────────────────────
156
+ function showBlockedPage(params) {
157
+ loadingState.style.display = "none";
158
+ blockedOverlay.classList.add("show");
159
+
160
+ const url = decodeURIComponent(params.get("url") || "");
161
+ const score = params.get("score") || "0";
162
+ const method = decodeURIComponent(params.get("method") || "");
163
+
164
+ $id("blockedUrl").textContent = url;
165
+ $id("blockedMethod").textContent = `Detection: ${method} Β· Risk: ${score}%`;
166
+ $id("currentUrl").textContent = url;
167
+
168
+ currentResult = { url, status: "blocked", confidence: parseInt(score) / 100, method };
169
+ computeUrlHash(url);
170
+
171
+ // Show feedback for blocked pages too
172
+ feedbackSection.style.display = "block";
173
+ retrainSection.style.display = "block";
174
+ statsRow.style.display = "flex";
175
+ loadStatus();
176
+
177
+ $id("proceedBtn").onclick = () => {
178
+ chrome.runtime.sendMessage({ type: "whitelist_url", url }, () => {
179
+ chrome.tabs.update({ url });
180
+ });
181
+ };
182
+ }
183
+
184
+ // ── Feedback ──────────────────────────────────────────────────
185
+ async function computeUrlHash(url) {
186
+ if (!url) return;
187
+ const encoded = new TextEncoder().encode(url);
188
+ const hash = await crypto.subtle.digest("SHA-256", encoded);
189
+ currentUrlHash = Array.from(new Uint8Array(hash))
190
+ .map(b => b.toString(16).padStart(2, "0")).join("");
191
+ }
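+ // e.g. computeUrlHash("https://a.com") stores a 64-char lowercase hex
+ // SHA-256 digest, the same url_hash format the backend keeps in FeedbackRecord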
192
+
193
+ function submitFeedback(feedback) {
194
+ if (feedbackGiven || !currentUrlHash) return;
195
+ feedbackGiven = true;
196
+
197
+ chrome.runtime.sendMessage({
198
+ type: "submit_feedback",
199
+ url_hash: currentUrlHash,
200
+ feedback: feedback,
201
+ }, response => {
202
+ if (response?.success) {
203
+ // Highlight selected, dim other
204
+ const correctBtn = $id("btnCorrect");
205
+ const wrongBtn = $id("btnWrong");
206
+
207
+ if (feedback === "correct") {
208
+ correctBtn.classList.add("selected");
209
+ wrongBtn.classList.add("dimmed");
210
+ } else {
211
+ wrongBtn.classList.add("selected");
212
+ correctBtn.classList.add("dimmed");
213
+ }
214
+
215
+ $id("thankYou").classList.add("show");
216
+
217
+ // Allow changing within 5 minutes
218
+ feedbackTimeout = setTimeout(() => {
219
+ feedbackGiven = false;
220
+ correctBtn.classList.remove("selected", "dimmed");
221
+ wrongBtn.classList.remove("selected", "dimmed");
222
+ }, 5 * 60 * 1000);
223
+ }
224
+ });
225
+ }
226
+
227
+ function checkExistingFeedback() {
228
+ // Check if feedback was already given for this URL
229
+ chrome.storage.local.get("phishguard_feedback_queue", data => {
230
+ const queue = data.phishguard_feedback_queue || [];
231
+ const record = queue.find(r => r.url_hash === currentUrlHash);
232
+ if (record?.user_feedback) {
233
+ feedbackGiven = true;
234
+ const correctBtn = $id("btnCorrect");
235
+ const wrongBtn = $id("btnWrong");
236
+ if (record.user_feedback === "correct") {
237
+ correctBtn.classList.add("selected");
238
+ wrongBtn.classList.add("dimmed");
239
+ } else {
240
+ wrongBtn.classList.add("selected");
241
+ correctBtn.classList.add("dimmed");
242
+ }
243
+ }
244
+ });
245
+ }
246
+
247
+ // ── Retraining Status ─────────────────────────────────────────
248
+ function loadStatus() {
249
+ chrome.runtime.sendMessage({ type: "get_status" }, status => {
250
+ if (!status) return;
251
+
252
+ // Stats row
253
+ $id("statScanned").textContent = `πŸ“Š ${status.scan_count} scanned`;
254
+ $id("statFeedback").textContent = `πŸ’¬ ${status.labeled_count || 0} labeled`;
255
+ $id("statVersion").textContent = `🏷️ v${status.model_version}`;
256
+ $id("versionBadge").textContent = `v${status.model_version || "3.0"}`;
257
+
258
+ // Retrain progress
259
+ const urlsRemaining = status.next_retrain_urls_remaining || 50;
260
+ const progress = Math.round(((50 - urlsRemaining) / 50) * 100);
261
+ $id("retrainProgressBar").style.width = `${progress}%`;
262
+
263
+ // Retrain status text
264
+ const timeMs = status.next_retrain_time_remaining_ms || 0;
265
+ const hours = Math.floor(timeMs / 3600000);
266
+ const mins = Math.floor((timeMs % 3600000) / 60000);
267
+
268
+ const labeledNeeded = status.min_labeled_needed || 0;
269
+ if (labeledNeeded > 0) {
270
+ $id("retrainStatus").textContent =
271
+ `Need ${labeledNeeded} more feedback to retrain`;
272
+ } else {
273
+ $id("retrainStatus").textContent =
274
+ `Next retrain: ${urlsRemaining} URLs or ${hours}h ${mins}m`;
275
+ }
276
+
277
+ // Last retrain info
278
+ if (status.last_retrain_ts) {
279
+ const ago = timeSince(new Date(status.last_retrain_ts));
280
+ $id("retrainLast").textContent = `Last retrain: ${ago} ago`;
281
+ }
282
+
283
+ // Start countdown
284
+ startCountdown(timeMs);
285
+ });
286
+ }
287
+
288
+ function startCountdown(initialMs) {
289
+ if (countdownInterval) clearInterval(countdownInterval);
290
+ let remaining = initialMs;
291
+
292
+ countdownInterval = setInterval(() => {
293
+ remaining -= 1000;
294
+ if (remaining <= 0) {
295
+ clearInterval(countdownInterval);
296
+ $id("retrainStatus").textContent = "Retrain pending...";
297
+ return;
298
+ }
299
+ const h = Math.floor(remaining / 3600000);
300
+ const m = Math.floor((remaining % 3600000) / 60000);
301
+ const s = Math.floor((remaining % 60000) / 1000);
302
+
303
+ // Only update the time portion if it's the time display
304
+ const el = $id("retrainStatus");
305
+ if (el.textContent.includes("URLs or")) {
306
+ const parts = el.textContent.split(" or ");
307
+ el.textContent = `${parts[0]} or ${h}h ${m}m ${s}s`;
308
+ }
309
+ }, 1000);
310
+ }
311
+
312
+ function timeSince(date) {
313
+ const seconds = Math.floor((Date.now() - date.getTime()) / 1000);
314
+ if (seconds < 60) return `${seconds}s`;
315
+ if (seconds < 3600) return `${Math.floor(seconds / 60)}m`;
316
+ if (seconds < 86400) return `${Math.floor(seconds / 3600)}h`;
317
+ return `${Math.floor(seconds / 86400)}d`;
318
+ }
319
+
320
+ // ── Tier Row Toggle ───────────────────────────────────────────
321
+ // MV3 extension pages disallow inline onclick handlers, so bind the tier toggles here.
322
+ document.querySelectorAll(".tier-header").forEach(header =>
323
+ header.addEventListener("click", () =>
324
+ header.parentElement.classList.toggle("open")));
325
+
326
+ // ── Event Listeners ───────────────────────────────────────────
327
+ $id("btnCorrect").addEventListener("click", () => submitFeedback("correct"));
328
+ $id("btnWrong").addEventListener("click", () => submitFeedback("incorrect"));
329
+
330
+ // ── Start ─────────────────────────────────────────────────────
331
+ init();
332
+ })();
render.yaml ADDED
@@ -0,0 +1,34 @@
1
+ # ============================================================
2
+ # render.yaml β€” Render.com deployment config (architecture doc 5.2)
3
+ #
4
+ # PLAYWRIGHT ON RENDER:
5
+ # To enable Tier 4 visual analysis (Playwright + Chromium), you need:
6
+ #
7
+ # 1. Add ENABLE_VISUAL_TIER=1 env var below
8
+ # 2. Switch from python:3.10-slim to a full image in Dockerfile
9
+ # OR add Chromium system deps to the Dockerfile:
10
+ # apt-get install -y libnss3 libatk1.0-0 libatk-bridge2.0-0 \
11
+ # libcups2 libxkbcommon0 libgbm1 libpango-1.0-0 \
12
+ # libcairo2 libasound2 libxdamage1 libxrandr2 libxfixes3
13
+ # 3. Add to Dockerfile after pip install:
14
+ # RUN pip install playwright && playwright install chromium
15
+ #
16
+ # NOTE: Playwright + Chromium adds ~400MB to the Docker image.
17
+ # On the free tier (512MB RAM), this may cause OOM.
18
+ # Only enable if you have a paid plan with >= 1GB RAM.
19
+ # ============================================================
20
+
21
+ services:
22
+ - type: web
23
+ name: phishguard-api
24
+ runtime: docker
25
+ dockerfilePath: ./Dockerfile
26
+ plan: free
27
+ healthCheckPath: /health
28
+ autoDeploy: true
29
+ envVars:
30
+ - key: PORT
31
+ value: "8000"
32
+ # Uncomment below to enable Tier 4 visual analysis (needs Playwright in Dockerfile)
33
+ # - key: ENABLE_VISUAL_TIER
34
+ # value: "1"
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ fastapi==0.111.0
2
+ uvicorn[standard]==0.29.0
3
+ transformers==4.40.0
4
+ torch==2.2.2
5
+ torch-geometric==2.5.2
6
+ torchvision==0.17.2
7
+ playwright==1.44.0
8
+ pillow==10.3.0
9
+ scikit-learn==1.4.2
10
+ pandas==2.2.2
11
+ numpy==1.26.4
12
+ httpx==0.27.0
13
+ imagehash==4.3.1
14
+ requests==2.31.0
15
+ aiohttp==3.9.5
16
+ aiofiles==23.2.1
17
+ python-multipart==0.0.9
18
+ apscheduler==3.10.4
retraining_service.py ADDED
@@ -0,0 +1,295 @@
1
+ # ============================================================
2
+ # PhishGuard AI - retraining_service.py
3
+ # Incremental retraining service for all 3 ML models.
4
+ #
5
+ # Receives labeled feedback samples from the Chrome extension.
6
+ # Runs parallel incremental updates for BERT, GNN, and CNN.
7
+ # Tracks model version and accuracy deltas.
8
+ # Supports hot-reload of all models without server restart.
9
+ # ============================================================
10
+
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import logging
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple
20
+
21
+ logger = logging.getLogger("phishguard.retrain")
22
+
23
+ DATA_DIR = Path(__file__).parent / "data"
24
+ MODEL_VERSION_PATH = DATA_DIR / "model_version.json"
25
+
26
+
27
+ @dataclass
28
+ class FeedbackRecord:
29
+ """A single feedback record from the Chrome extension."""
30
+ url: str
31
+ verdict: str # "phishing" or "safe"
32
+ confidence: float = 0.0
33
+ tier_used: int = 0
34
+ heuristic_score: int = 0
35
+ signals: List[str] = field(default_factory=list)
36
+ user_feedback: Optional[str] = None # "correct" or "incorrect"
37
+ timestamp: str = ""
38
+ feedback_ts: Optional[str] = None
39
+ url_hash: str = ""
40
+ session_id: str = ""
41
+
42
+
43
+ @dataclass
44
+ class RetrainResult:
45
+ """Result from a retraining run."""
46
+ status: str # "success", "skipped", "error"
47
+ models_updated: List[str] = field(default_factory=list)
48
+ samples_used: int = 0
49
+ duration_seconds: float = 0.0
50
+ accuracy_delta: Dict[str, Optional[float]] = field(default_factory=dict)
51
+ next_retrain_hint: Dict = field(default_factory=dict)
52
+
53
+
54
+ class RetrainingService:
55
+ """
56
+ Orchestrates incremental retraining for all 3 ML models.
57
+ Called by POST /retrain endpoint.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ bert_classifier,
63
+ gnn_inference,
64
+ cnn_inference,
65
+ ) -> None:
66
+ self._bert = bert_classifier
67
+ self._gnn = gnn_inference
68
+ self._cnn = cnn_inference
69
+ self._model_version = self._load_version()
70
+
71
+ def _load_version(self) -> int:
72
+ """Load current model version from disk."""
73
+ MODEL_VERSION_PATH.parent.mkdir(parents=True, exist_ok=True)
74
+ if MODEL_VERSION_PATH.exists():
75
+ try:
76
+ data = json.loads(MODEL_VERSION_PATH.read_text())
77
+ return data.get("version", 0)
78
+ except Exception:
79
+ pass
80
+ return 0
81
+
82
+ def _save_version(self, accuracy_delta: Dict[str, Optional[float]]) -> None:
83
+ """Save updated model version to disk."""
84
+ self._model_version += 1
85
+ data = {
86
+ "version": self._model_version,
87
+ "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
88
+ "accuracy": accuracy_delta,
89
+ }
90
+ MODEL_VERSION_PATH.write_text(json.dumps(data, indent=2))
91
+
92
+ @property
93
+ def model_version(self) -> int:
94
+ return self._model_version
95
+
96
+ def get_version_info(self) -> dict:
97
+ """Get current model version info for GET /model_version."""
98
+ if MODEL_VERSION_PATH.exists():
99
+ try:
100
+ return json.loads(MODEL_VERSION_PATH.read_text())
101
+ except Exception:
102
+ pass
103
+ return {
104
+ "version": self._model_version,
105
+ "updated_at": None,
106
+ "accuracy": {},
107
+ }
108
+
109
+ async def retrain(
110
+ self,
111
+ samples: List[FeedbackRecord],
112
+ ) -> RetrainResult:
113
+ """
114
+ Perform incremental retraining on all models.
115
+
116
+ Steps:
117
+ 1. Validate samples (min 10, URL format check)
118
+ 2. Separate by tier_used for targeted updates
119
+ 3. Run BERT + GNN updates in parallel
120
+ 4. Run CNN update if Tier 4 samples exist
121
+ 5. Compute accuracy_delta for each model
122
+ 6. Increment model version
123
+ 7. Hot-reload all models
124
+
125
+ Returns RetrainResult with status and deltas.
126
+ """
127
+ start_time = time.time()
128
+
129
+ # 1. Validate
130
+ valid_samples = self._validate_samples(samples)
131
+ if len(valid_samples) < 10:
132
+ return RetrainResult(
133
+ status="skipped",
134
+ samples_used=len(valid_samples),
135
+ next_retrain_hint={
136
+ "recommended_trigger": "count",
137
+ "min_samples_needed": 10 - len(valid_samples),
138
+ },
139
+ )
140
+
141
+ # 2. Convert to (url, label) pairs
142
+ url_label_pairs: List[Tuple[str, int]] = []
143
+ tier4_pairs: List[Tuple[str, int]] = []
144
+
145
+ for sample in valid_samples:
146
+ # Determine the true label based on user feedback
147
+ if sample.user_feedback == "correct":
148
+ label = 1 if sample.verdict == "phishing" else 0
149
+ elif sample.user_feedback == "incorrect":
150
+ label = 0 if sample.verdict == "phishing" else 1
151
+ else:
152
+ continue
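+ # e.g. verdict="phishing" with feedback="incorrect" marks a false positive,
+ # so the corrected training label is 0 (safe)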
153
+
154
+ url_label_pairs.append((sample.url, label))
155
+ if sample.tier_used == 4:
156
+ tier4_pairs.append((sample.url, label))
157
+
158
+ if len(url_label_pairs) < 5:
159
+ return RetrainResult(
160
+ status="skipped",
161
+ samples_used=len(url_label_pairs),
162
+ next_retrain_hint={
163
+ "recommended_trigger": "count",
164
+ "min_samples_needed": 5,
165
+ },
166
+ )
167
+
168
+ # 3. Run updates
169
+ models_updated: List[str] = []
170
+ accuracy_delta: Dict[str, Optional[float]] = {}
171
+
172
+ try:
173
+ # BERT + GNN in parallel
174
+ loop = asyncio.get_running_loop()
175
+
176
+ bert_task = loop.run_in_executor(
177
+ None,
178
+ self._bert.incremental_update,
179
+ url_label_pairs,
180
+ )
181
+
182
+ gnn_task = loop.run_in_executor(
183
+ None,
184
+ self._gnn.incremental_update,
185
+ url_label_pairs,
186
+ )
187
+
188
+ bert_delta, gnn_delta = await asyncio.gather(
189
+ bert_task, gnn_task,
190
+ return_exceptions=True,
191
+ )
192
+
193
+ # Process BERT result
194
+ if isinstance(bert_delta, Exception):
195
+ logger.error(f"BERT update error: {bert_delta}")
196
+ accuracy_delta["bert"] = None
197
+ elif bert_delta is not None:
198
+ accuracy_delta["bert"] = bert_delta
199
+ models_updated.append("bert")
200
+ else:
201
+ accuracy_delta["bert"] = None
202
+
203
+ # Process GNN result
204
+ if isinstance(gnn_delta, Exception):
205
+ logger.error(f"GNN update error: {gnn_delta}")
206
+ accuracy_delta["gnn"] = None
207
+ elif gnn_delta is not None:
208
+ accuracy_delta["gnn"] = gnn_delta
209
+ models_updated.append("gnn")
210
+ else:
211
+ accuracy_delta["gnn"] = None
212
+
213
+ # 4. CNN update (only if Tier 4 samples exist)
214
+ if tier4_pairs:
215
+ try:
216
+ cnn_delta = await self._cnn.incremental_update(tier4_pairs)
217
+ if cnn_delta is not None:
218
+ accuracy_delta["cnn"] = cnn_delta
219
+ models_updated.append("cnn")
220
+ else:
221
+ accuracy_delta["cnn"] = None
222
+ except Exception as e:
223
+ logger.error(f"CNN update error: {e}")
224
+ accuracy_delta["cnn"] = None
225
+ else:
226
+ accuracy_delta["cnn"] = None
227
+
228
+ # 5. Update version
229
+ if models_updated:
230
+ self._save_version(accuracy_delta)
231
+
232
+ # 6. Hot-reload
233
+ await self._hot_reload(models_updated)
234
+
235
+ duration = time.time() - start_time
236
+
237
+ return RetrainResult(
238
+ status="success" if models_updated else "skipped",
239
+ models_updated=models_updated,
240
+ samples_used=len(url_label_pairs),
241
+ duration_seconds=round(duration, 2),
242
+ accuracy_delta=accuracy_delta,
243
+ next_retrain_hint={
244
+ "recommended_trigger": "count",
245
+ "min_samples_needed": 10,
246
+ },
247
+ )
248
+
249
+ except Exception as e:
250
+ logger.error(f"Retraining failed: {e}")
251
+ return RetrainResult(
252
+ status="error",
253
+ duration_seconds=round(time.time() - start_time, 2),
254
+ accuracy_delta=accuracy_delta,
255
+ )
256
+
257
+ def _validate_samples(self, samples: List[FeedbackRecord]) -> List[FeedbackRecord]:
258
+ """Validate and filter feedback samples."""
259
+ valid = []
260
+ for s in samples:
261
+ # Must have user feedback
262
+ if not s.user_feedback:
263
+ continue
264
+ if s.user_feedback not in ("correct", "incorrect"):
265
+ continue
266
+ # Must have a valid URL
267
+ if not s.url or not s.url.startswith(("http://", "https://")):
268
+ continue
269
+ valid.append(s)
270
+ return valid
271
+
272
+ async def _hot_reload(self, models: List[str]) -> None:
273
+ """Hot-reload updated models in-memory."""
274
+ if "bert" in models:
275
+ try:
276
+ bert_weights = Path(__file__).parent / "bert_weights"
277
+ if bert_weights.exists():
278
+ self._bert.load_local(bert_weights)
279
+ logger.info("BERT hot-reloaded")
280
+ except Exception as e:
281
+ logger.error(f"BERT hot-reload failed: {e}")
282
+
283
+ if "gnn" in models:
284
+ try:
285
+ self._gnn.reload()
286
+ logger.info("GNN hot-reloaded")
287
+ except Exception as e:
288
+ logger.error(f"GNN hot-reload failed: {e}")
289
+
290
+ if "cnn" in models:
291
+ try:
292
+ self._cnn.reload()
293
+ logger.info("CNN hot-reloaded")
294
+ except Exception as e:
295
+ logger.error(f"CNN hot-reload failed: {e}")
screenshot_collector.py ADDED
@@ -0,0 +1,172 @@
1
+ # ============================================================
2
+ # PhishGuard AI - screenshot_collector.py
3
+ # Batch screenshot capture for CNN training data generation.
4
+ #
5
+ # Uses Playwright async API with 10 concurrent captures.
6
+ # Blocks fonts, media, video for 60-70% speedup.
7
+ # Saves PNG named by URL SHA256 hash.
8
+ # ============================================================
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ import hashlib
14
+ import logging
15
+ import sys
16
+ from pathlib import Path
17
+ from typing import List
18
+
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format="%(asctime)s | %(levelname)-7s | %(message)s",
22
+ )
23
+ logger = logging.getLogger("phishguard.screenshot_collector")
24
+
25
+ BACKEND_DIR = Path(__file__).parent
26
+ DATA_DIR = BACKEND_DIR / "data"
27
+ SCREENSHOTS_DIR = DATA_DIR / "screenshots"
28
+
29
+
30
+ def url_to_filename(url: str) -> str:
31
+ """Convert URL to safe filename using SHA256 hash."""
32
+ url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
33
+ return f"{url_hash}.png"
34
+
35
+
36
+ async def capture_single(
37
+ url: str,
38
+ save_dir: Path,
39
+ semaphore: asyncio.Semaphore,
40
+ browser,
41
+ ) -> bool:
42
+ """Capture a single screenshot with concurrency limiting."""
43
+ async with semaphore:
44
+ filename = url_to_filename(url)
45
+ filepath = save_dir / filename
46
+
47
+ # Skip if already captured
48
+ if filepath.exists():
49
+ return True
50
+
51
+ page = None  # defined up front so the error path can close it safely
+ try:
52
+ page = await browser.new_page(
53
+ viewport={"width": 1280, "height": 800},
54
+ user_agent=(
55
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
56
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
57
+ "Chrome/120.0.0.0 Safari/537.36"
58
+ ),
59
+ )
60
+
61
+ # Block heavy resources for speed
62
+ await page.route(
63
+ "**/*.{woff,woff2,ttf,eot,mp4,webm,ogg,avi,mp3,wav,flac}",
64
+ lambda route: route.abort(),
65
+ )
66
+
67
+ await page.goto(
68
+ url,
69
+ wait_until="domcontentloaded",
70
+ timeout=10000,
71
+ )
72
+
73
+ # Brief wait for rendering
74
+ await asyncio.sleep(0.5)
75
+
76
+ screenshot = await page.screenshot(type="png")
77
+ filepath.write_bytes(screenshot)
78
+
79
+ await page.close()
80
+ return True
81
+
82
+ except Exception as e:
83
+ logger.debug(f"Screenshot failed for {url}: {e}")
84
+ try:
85
+ if page: await page.close()
86
+ except Exception:
87
+ pass
88
+ return False
89
+
90
+
91
+ async def batch_capture(
92
+ urls: List[str],
93
+ save_dir: Path,
94
+ concurrency: int = 10,
95
+ label: str = "urls",
96
+ ) -> int:
97
+ """
98
+ Capture screenshots for a batch of URLs concurrently.
99
+ Returns count of successful captures.
100
+ """
101
+ save_dir.mkdir(parents=True, exist_ok=True)
102
+
103
+ try:
104
+ from playwright.async_api import async_playwright
105
+ except ImportError:
106
+ logger.error("Playwright not installed. Run: pip install playwright && playwright install chromium")
107
+ return 0
108
+
109
+ semaphore = asyncio.Semaphore(concurrency)
110
+ success_count = 0
111
+
112
+ async with async_playwright() as p:
113
+ browser = await p.chromium.launch(headless=True)
114
+
115
+ tasks = [
116
+ capture_single(url, save_dir, semaphore, browser)
117
+ for url in urls
118
+ ]
119
+
120
+ results = await asyncio.gather(*tasks, return_exceptions=True)
121
+
122
+ for i, result in enumerate(results):
123
+ if result is True:
124
+ success_count += 1
125
+ if (i + 1) % 50 == 0:
126
+ logger.info(f" {label}: {i+1}/{len(urls)} processed ({success_count} captured)")
127
+
128
+ await browser.close()
129
+
130
+ logger.info(f" {label}: {success_count}/{len(urls)} screenshots captured")
131
+ return success_count
132
+
133
+
134
+ async def collect_training_screenshots(
135
+ phish_count: int = 10,
136
+ legit_count: int = 10,
137
+ ) -> None:
138
+ """Collect screenshots for CNN training."""
139
+ from data_collector import download_phishtank, download_tranco
140
+
141
+ phishing_dir = SCREENSHOTS_DIR / "phishing"
142
+ legitimate_dir = SCREENSHOTS_DIR / "legitimate"
143
+
144
+ # Download URL lists
145
+ print("πŸ“₯ Loading URL lists...")
146
+ phish_urls = download_phishtank(max_urls=phish_count)[:phish_count]
147
+ legit_urls = download_tranco(n=legit_count)[:legit_count]
148
+
149
+ print(f"\nπŸ“Έ Capturing phishing screenshots ({len(phish_urls)} URLs)...")
150
+ phish_success = await batch_capture(phish_urls, phishing_dir, label="Phishing")
151
+
152
+ print(f"\nπŸ“Έ Capturing legitimate screenshots ({len(legit_urls)} URLs)...")
153
+ legit_success = await batch_capture(legit_urls, legitimate_dir, label="Legitimate")
154
+
155
+ print(f"\nβœ… Screenshots collected:")
156
+ print(f" Phishing: {phish_success}/{len(phish_urls)}")
157
+ print(f" Legitimate: {legit_success}/{len(legit_urls)}")
158
+ print(f" Saved to: {SCREENSHOTS_DIR}")
159
+
160
+
161
+ def main() -> None:
162
+ print("=" * 60)
163
+ print("PhishGuard AI οΏ½οΏ½οΏ½ Screenshot Collection")
164
+ print("=" * 60)
165
+
166
+ asyncio.run(collect_training_screenshots())
167
+
168
+ print("=" * 60)
169
+
170
+
171
+ if __name__ == "__main__":
172
+ main()
screenshot_hasher.py ADDED
@@ -0,0 +1,214 @@
1
+ # ============================================================
2
+ # PhishGuard AI - cnn/screenshot_hasher.py
3
+ # Perceptual hash-based brand impersonation detector.
4
+ #
5
+ # Compares webpage screenshots against reference hashes of
6
+ # known brand login pages using imagehash.phash.
7
+ #
8
+ # brand_boost = 0.25 if hamming_distance < 10 else 0.0
9
+ # ============================================================
10
+
11
+ from __future__ import annotations
12
+
13
+ import io
14
+ import json
15
+ import logging
16
+ from pathlib import Path
17
+ from typing import Tuple, Optional, Dict, List
18
+
19
+ from PIL import Image
20
+
21
+ logger = logging.getLogger("phishguard.cnn.hasher")
22
+
23
+ # ── Try to use imagehash, fall back to custom implementation ─────────
24
+ _imagehash_available = False
25
+ try:
26
+ import imagehash
27
+ _imagehash_available = True
28
+ except ImportError:
29
+ logger.info("imagehash not installed β€” using built-in phash")
30
+
31
+ HASH_DB_PATH = Path(__file__).parent / "brand_hashes.json"
32
+
33
+
34
+ class BrandHashDetector:
35
+ """
36
+ Perceptual hash-based brand impersonation detector.
37
+ Compares screenshots against reference hashes of 10 major brands.
38
+ """
39
+
40
+ BRANDS: List[str] = [
41
+ "paypal", "google", "apple", "microsoft", "amazon",
42
+ "chase", "netflix", "facebook", "instagram", "wellsfargo",
43
+ ]
44
+
45
+ BRAND_DOMAINS: Dict[str, str] = {
46
+ "paypal": "paypal.com",
47
+ "google": "google.com",
48
+ "apple": "apple.com",
49
+ "microsoft": "microsoft.com",
50
+ "amazon": "amazon.com",
51
+ "chase": "chase.com",
52
+ "netflix": "netflix.com",
53
+ "facebook": "facebook.com",
54
+ "instagram": "instagram.com",
55
+ "wellsfargo": "wellsfargo.com",
56
+ }
57
+
58
+ def __init__(self, hash_db_path: Optional[Path] = None) -> None:
59
+ self._hash_db_path = hash_db_path or HASH_DB_PATH
60
+ self._reference_hashes: Dict[str, dict] = {}
61
+ self._load_reference_hashes()
62
+
63
+ def _load_reference_hashes(self) -> None:
64
+ """Load reference hashes from JSON database."""
65
+ if self._hash_db_path.exists():
66
+ try:
67
+ with open(self._hash_db_path) as f:
68
+ self._reference_hashes = json.load(f)
69
+ logger.info(f"Loaded {len(self._reference_hashes)} brand hashes")
70
+ except Exception as e:
71
+ logger.warning(f"Failed to load brand hashes: {e}")
72
+ self._reference_hashes = {}
73
+ else:
74
+ logger.info("No brand hash DB found β€” brand detection disabled")
75
+ self._reference_hashes = {}
76
+
77
+ def compute_hash(self, img_bytes: bytes, hash_size: int = 16) -> Optional[int]:
78
+ """
79
+ Compute perceptual hash of an image.
80
+ Uses imagehash.phash if available, otherwise custom DCT-less implementation.
81
+ """
82
+ try:
83
+ img = Image.open(io.BytesIO(img_bytes))
84
+
85
+ if _imagehash_available:
86
+ h = imagehash.phash(img, hash_size=hash_size)
87
+ return int(str(h), 16)
88
+ else:
89
+ return self._custom_phash(img, hash_size)
90
+
91
+ except Exception as e:
92
+ logger.warning(f"Hash computation failed: {e}")
93
+ return None
94
+
95
+ def _custom_phash(self, img: Image.Image, hash_size: int = 16) -> int:
96
+ """Fallback perceptual hash (mean-based, no DCT)."""
97
+ img = img.convert("L").resize((hash_size, hash_size), Image.LANCZOS)
98
+ pixels = list(img.getdata())
99
+ avg = sum(pixels) / len(pixels)
100
+ bits = "".join("1" if p > avg else "0" for p in pixels)
101
+ return int(bits, 2)
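+ # for the default 16x16 hash this packs 256 bits into one int, where bit i
+ # is set when pixel i is brighter than the image mean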
102
+
103
+ def hamming_distance(self, h1: int, h2: int) -> int:
104
+ """Count bit differences between two hashes. 0 = identical."""
105
+ return bin(h1 ^ h2).count("1")
106
+
107
+ def detect(
108
+ self,
109
+ screenshot_bytes: bytes,
110
+ url: str = "",
111
+ threshold: int = 10,
112
+ ) -> Tuple[bool, str, float]:
113
+ """
114
+ Detect brand impersonation from screenshot.
115
+
116
+ Returns:
117
+ (is_impersonation, brand_name, confidence)
118
+ is_impersonation: True if page looks like a brand but URL doesn't match
119
+ brand_name: detected brand name or ""
120
+ confidence: 0.0-1.0 similarity score
121
+ """
122
+ page_hash = self.compute_hash(screenshot_bytes)
123
+ if page_hash is None:
124
+ return False, "", 0.0
125
+
126
+ url_lower = url.lower()
127
+ best_match: Optional[str] = None
128
+ best_distance = 999
129
+ best_confidence = 0.0
130
+
131
+ for brand, entry in self._reference_hashes.items():
132
+ try:
133
+ stored_hash = int(entry["hash"])
134
+ distance = self.hamming_distance(page_hash, stored_hash)
135
+ confidence = max(0.0, 1.0 - distance / 256.0)
136
+
137
+ if distance < best_distance:
138
+ best_distance = distance
139
+ best_match = brand
140
+ best_confidence = confidence
141
+
142
+ except (ValueError, KeyError):
143
+ continue
144
+
145
+ if best_match and best_distance <= threshold:
146
+ legit_domain = self.BRAND_DOMAINS.get(best_match, f"{best_match}.com")
147
+
148
+ # Check if URL belongs to legitimate domain
149
+ if legit_domain not in url_lower:
150
+ return True, best_match, best_confidence
151
+ else:
152
+ return False, best_match, best_confidence
153
+
154
+ return False, "", 0.0
155
+
156
+ def register_brand(
157
+ self,
158
+ brand_name: str,
159
+ domain: str,
160
+ screenshot_bytes: bytes,
161
+ ) -> bool:
162
+ """Register a brand's reference screenshot hash."""
163
+ h = self.compute_hash(screenshot_bytes)
164
+ if h is None:
165
+ return False
166
+
167
+ self._reference_hashes[brand_name] = {
168
+ "domain": domain,
169
+ "hash": str(h),
170
+ }
171
+
172
+ # Save to disk
173
+ try:
174
+ with open(self._hash_db_path, "w") as f:
175
+ json.dump(self._reference_hashes, f, indent=2)
176
+ logger.info(f"Registered brand: {brand_name} ({domain})")
177
+ return True
178
+ except Exception as e:
179
+ logger.error(f"Failed to save brand hash: {e}")
180
+ return False
181
+
182
+
183
+ # ── Legacy compatibility ─────────────────────────────────────────────
184
+ _detector = BrandHashDetector()
185
+
186
+
187
+ def check_brand_impersonation(
188
+ screenshot_bytes: bytes,
189
+ url: str,
190
+ similarity_threshold: int = 10,
191
+ ) -> dict:
192
+ """Legacy wrapper for backward compatibility."""
193
+ is_impersonation, brand, confidence = _detector.detect(
194
+ screenshot_bytes, url, similarity_threshold,
195
+ )
196
+
197
+ if is_impersonation:
198
+ return {
199
+ "impersonation_detected": True,
200
+ "impersonated_brand": brand,
201
+ "legitimate_domain": _detector.BRAND_DOMAINS.get(brand, ""),
202
+ "visual_similarity": round(confidence, 3),
203
+ }
204
+ elif brand:
205
+ return {
206
+ "impersonation_detected": False,
207
+ "matched_brand": brand,
208
+ "note": "legitimate site",
209
+ }
210
+ else:
211
+ return {
212
+ "impersonation_detected": False,
213
+ "reason": "no_brand_match",
214
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
test_endpoint.py ADDED
@@ -0,0 +1,34 @@
+ import asyncio
+ import json
+
+ from main import analyze_email_endpoint, EmailRequest
+
+
+ async def run_tests():
+     print("--- Testing Tier 1: Whitelist ---")
+     res1 = await analyze_email_endpoint(EmailRequest(
+         sender="noreply@github.com",
+         subject="Your receipt",
+         body="...",
+         urls=[]
+     ))
+     print(json.dumps(res1, indent=2))
+
+     print("\n--- Testing Tier 2: Text Heuristic (No URLs) ---")
+     res2 = await analyze_email_endpoint(EmailRequest(
+         sender="admin@unknown-domain.com",
+         subject="URGENT: Password Reset Required",
+         body="Please reset immediately.",
+         urls=[]
+     ))
+     print(json.dumps(res2, indent=2))
+
+     print("\n--- Testing Tier 2: Async URLs ---")
+     res3 = await analyze_email_endpoint(EmailRequest(
+         sender="service@paypal-update.net",
+         subject="Important Account Update",
+         body="Click the link to verify your account.",
+         urls=["http://chase-bank-verify-login.cx/auth", "http://google.com/"]
+     ))
+     print(json.dumps(res3, indent=2))
+
+
+ if __name__ == "__main__":
+     asyncio.run(run_tests())
tier3_bert_gnn.py ADDED
@@ -0,0 +1,153 @@
+ # ============================================================
+ # PhishGuard AI - tier3_bert_gnn.py
+ # Tier 3: BERT + GNN Parallel Ensemble
+ #
+ # Triggered only when Tier 2 score < 80.
+ # BERT and GNN run in PARALLEL via asyncio.gather + run_in_executor.
+ #
+ # Ensemble formula:
+ #   P3 = 0.45·P_bert + 0.35·P_gnn + 0.20·(H_score/100)
+ #
+ # Decision:
+ #   P3 >= 0.85        → BLOCK
+ #   P3 < 0.40         → SAFE
+ #   0.40 <= P3 < 0.85 → escalate to Tier 4
+ # ============================================================
+
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+
+ logger = logging.getLogger("phishguard.tier3")
+
+
+ class Tier3Ensemble:
+     """
+     Tier 3: BERT + GNN parallel ensemble classifier.
+
+     Runs BERT and GNN inference in parallel using asyncio.gather
+     with run_in_executor for non-blocking thread pool execution.
+     """
+
+     # Ensemble weights
+     W_BERT: float = 0.45
+     W_GNN: float = 0.35
+     W_HEURISTIC: float = 0.20
+
+     def __init__(
+         self,
+         bert_classifier,
+         gnn_inference,
+     ) -> None:
+         self._bert = bert_classifier
+         self._gnn = gnn_inference
+
+     async def predict(
+         self,
+         url: str,
+         title: str = "",
+         snippet: str = "",
+         h_score: int = 0,
+     ) -> float:
+         """
+         Run BERT + GNN in parallel and compute ensemble score.
+
+         Args:
+             url: The URL to analyze
+             title: Page title (optional)
+             snippet: Page content snippet (optional)
+             h_score: Heuristic score from Tier 2 (0-100, passed through, NOT recomputed)
+
+         Returns:
+             P3 ∈ [0,1] — ensemble phishing probability
+         """
+         # get_running_loop() is the non-deprecated call inside a coroutine
+         loop = asyncio.get_running_loop()
+
+         # Run BERT and GNN in parallel (both are CPU-bound, use thread pool)
+         bert_task = self._bert_predict(url, title, snippet, loop)
+         gnn_task = self._gnn_predict(url, loop)
+
+         p_bert, p_gnn = await asyncio.gather(bert_task, gnn_task)
+
+         # Ensemble: P3 = 0.45·P_bert + 0.35·P_gnn + 0.20·H_norm
+         h_norm = h_score / 100.0
+         p3 = (self.W_BERT * p_bert) + (self.W_GNN * p_gnn) + (self.W_HEURISTIC * h_norm)
+
+         logger.info(
+             f"Tier3 ensemble | url={url[:60]} | "
+             f"P_bert={p_bert:.4f} P_gnn={p_gnn:.4f} H_norm={h_norm:.4f} → P3={p3:.4f}"
+         )
+
+         return round(min(max(p3, 0.0), 1.0), 4)
+
+     async def _bert_predict(
+         self,
+         url: str,
+         title: str,
+         snippet: str,
+         loop: asyncio.AbstractEventLoop,
+     ) -> float:
+         """
+         Run BERT inference in thread pool (non-blocking).
+         Returns P_bert ∈ [0,1].
+         """
+         try:
+             p_bert = await asyncio.wait_for(
+                 loop.run_in_executor(
+                     None,  # Default thread pool
+                     self._bert.predict,
+                     url,
+                     title,
+                     snippet,
+                 ),
+                 timeout=10.0,
+             )
+             return float(p_bert)
+         except asyncio.TimeoutError:
+             logger.warning(f"BERT timeout for {url[:50]}")
+             return 0.5  # Neutral on timeout
+         except Exception as e:
+             logger.error(f"BERT predict error: {e}")
+             return 0.5
+
+     async def _gnn_predict(
+         self,
+         url: str,
+         loop: asyncio.AbstractEventLoop,
+     ) -> float:
+         """
+         Run GNN inference in thread pool (non-blocking).
+         Returns P_gnn ∈ [0,1].
+         """
+         try:
+             p_gnn = await asyncio.wait_for(
+                 loop.run_in_executor(
+                     None,
+                     self._gnn.predict,
+                     url,
+                     None,  # related_urls
+                 ),
+                 timeout=5.0,
+             )
+             return float(p_gnn)
+         except asyncio.TimeoutError:
+             logger.warning(f"GNN timeout for {url[:50]}")
+             return 0.5
+         except Exception as e:
+             logger.error(f"GNN predict error: {e}")
+             return 0.5
+
+     @staticmethod
+     def decide(p3: float) -> str:
+         """
+         Make decision based on P3 score.
+         Returns: 'block', 'safe', or 'escalate'
+         """
+         if p3 >= 0.85:
+             return "block"
+         elif p3 < 0.40:
+             return "safe"
+         else:
+             return "escalate"
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
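
These are standard Hugging Face tokenizer files, so they can be loaded directly with `transformers`; a minimal sketch, assuming the files sit in the local `bert_weights/` directory the Dockerfile creates:

from transformers import AutoTokenizer

# Picks up tokenizer_config.json, special_tokens_map.json, tokenizer.json, vocab.txt
tokenizer = AutoTokenizer.from_pretrained("bert_weights")
enc = tokenizer("http://paypal-secure-login.xyz/verify", truncation=True, max_length=512)
print(enc["input_ids"][:3])  # begins with 101, the [CLS] id defined above
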
train_cnn.py ADDED
@@ -0,0 +1,277 @@
+ # ============================================================
+ # PhishGuard AI - cnn/train_cnn.py
+ # CNN fine-tuning script for phishing screenshot detection.
+ #
+ # Loads data/screenshots/ with ImageFolder structure
+ # Augmentation: RandomHorizontalFlip, ColorJitter, RandomRotation
+ # AdamW on the classifier head only (backbone stays frozen);
+ # EPOCHS below defaults to 2 for a quick CPU run — raise it
+ # (e.g. to 15) for a real training pass.
+ # Saves cnn_weights.pt + cnn_replay_buffer.pt
+ # Works with as few as 100 images per class
+ # ============================================================
+
+ from __future__ import annotations
+
+ import logging
+ import sys
+ from pathlib import Path
+ from typing import List
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s | %(levelname)-7s | %(message)s",
+ )
+ logger = logging.getLogger("phishguard.cnn.train")
+
+ CNN_DIR = Path(__file__).parent
+ BACKEND_DIR = CNN_DIR.parent
+ WEIGHTS_PATH = CNN_DIR / "cnn_weights.pt"
+ REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "cnn_replay_buffer.pt"
+ SCREENSHOTS_DIR = BACKEND_DIR / "data" / "screenshots"
+
+ sys.path.insert(0, str(CNN_DIR))
+ sys.path.insert(0, str(BACKEND_DIR))
+
+
+ def main() -> None:
+     print("=" * 60)
+     print("PhishGuard AI — CNN Training")
+     print("=" * 60)
+
+     import torch
+     import torch.nn as nn
+     from torch.optim import AdamW
+     from torch.utils.data import DataLoader, Dataset
+     import torchvision.transforms as T
+     from PIL import Image
+     from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+
+     from cnn_model import PhishCNN
+
+     # ── Check data ───────────────────────────────────────────────
+     phishing_dir = SCREENSHOTS_DIR / "phishing"
+     legitimate_dir = SCREENSHOTS_DIR / "legitimate"
+
+     if not phishing_dir.exists() or not legitimate_dir.exists():
+         print("\n⚠️ Screenshot directories not found:")
+         print(f"   Expected: {phishing_dir}")
+         print(f"   Expected: {legitimate_dir}")
+         print("\n   Run: python screenshot_collector.py")
+
+         # Create dirs and generate placeholder images for testing
+         phishing_dir.mkdir(parents=True, exist_ok=True)
+         legitimate_dir.mkdir(parents=True, exist_ok=True)
+
+         print("   Generating synthetic training images...")
+         _generate_synthetic_screenshots(phishing_dir, legitimate_dir)
+
+     phishing_files = list(phishing_dir.glob("*.png")) + list(phishing_dir.glob("*.jpg"))
+     legit_files = list(legitimate_dir.glob("*.png")) + list(legitimate_dir.glob("*.jpg"))
+
+     print("\n📊 Dataset:")
+     print(f"   Phishing screenshots:   {len(phishing_files)}")
+     print(f"   Legitimate screenshots: {len(legit_files)}")
+
+     if len(phishing_files) < 10 or len(legit_files) < 10:
+         print("⚠️ Too few screenshots. Generating synthetic images...")
+         _generate_synthetic_screenshots(phishing_dir, legitimate_dir, count=100)
+         phishing_files = list(phishing_dir.glob("*.png"))
+         legit_files = list(legitimate_dir.glob("*.png"))
+         print(f"   Phishing: {len(phishing_files)}, Legitimate: {len(legit_files)}")
+
+     # ── Dataset ──────────────────────────────────────────────────
+     train_transform = T.Compose([
+         T.Resize((224, 224)),
+         T.RandomHorizontalFlip(),
+         T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
+         T.RandomRotation(5),
+         T.ToTensor(),
+         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     val_transform = T.Compose([
+         T.Resize((224, 224)),
+         T.ToTensor(),
+         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     class ScreenshotDataset(Dataset):
+         def __init__(self, files: List[Path], label: int, transform):
+             self.files = files
+             self.label = label
+             self.transform = transform
+
+         def __len__(self) -> int:
+             return len(self.files)
+
+         def __getitem__(self, idx: int):
+             try:
+                 img = Image.open(self.files[idx]).convert("RGB")
+                 tensor = self.transform(img)
+                 return tensor, self.label
+             except Exception:
+                 # Return black image on error
+                 tensor = torch.zeros(3, 224, 224)
+                 return tensor, self.label
+
+     # Split: 80% train, 20% val
+     import random
+     random.shuffle(phishing_files)
+     random.shuffle(legit_files)
+
+     phish_split = int(len(phishing_files) * 0.8)
+     legit_split = int(len(legit_files) * 0.8)
+
+     train_phish = phishing_files[:phish_split]
+     val_phish = phishing_files[phish_split:]
+     train_legit = legit_files[:legit_split]
+     val_legit = legit_files[legit_split:]
+
+     # Dataset.__add__ returns a ConcatDataset, so "+" merges the two classes
+     train_dataset = (
+         ScreenshotDataset(train_phish, 1, train_transform)
+         + ScreenshotDataset(train_legit, 0, train_transform)
+     )
+     val_dataset = (
+         ScreenshotDataset(val_phish, 1, val_transform)
+         + ScreenshotDataset(val_legit, 0, val_transform)
+     )
+
+     train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
+     val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
+
+     # ── Model ────────────────────────────────────────────────────
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"\n🤖 Device: {device}")
+
+     model = PhishCNN(pretrained=True).to(device)
+     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total = sum(p.numel() for p in model.parameters())
+     print(f"   Parameters: {total:,} total, {trainable:,} trainable")
+
+     # Only optimize head parameters
+     head_params = [p for p in model.backbone.fc.parameters() if p.requires_grad]
+     optimizer = AdamW(head_params, lr=1e-3, weight_decay=1e-4)
+     loss_fn = nn.BCELoss()
+
+     # ── Training ─────────────────────────────────────────────────
+     EPOCHS = 2
+     best_val_acc = 0.0
+
+     print(f"\n🏋️ Training for {EPOCHS} epochs...")
+     print(f"   {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7}")
+     print(f"   {'─'*5} | {'─'*8} | {'─'*9} | {'─'*7}")
+
+     for epoch in range(1, EPOCHS + 1):
+         # Train
+         model.train()
+         total_loss = 0.0
+         train_preds, train_labels = [], []
+
+         for batch_x, batch_y in train_loader:
+             batch_x = batch_x.to(device)
+             batch_y = batch_y.float().to(device)
+
+             optimizer.zero_grad()
+             # squeeze(-1) keeps a 1-D batch dim even when a batch has one sample
+             output = model(batch_x).squeeze(-1)
+             loss = loss_fn(output, batch_y)
+             loss.backward()
+             optimizer.step()
+
+             total_loss += loss.item()
+             preds = (output >= 0.5).int()
+             train_preds.extend(preds.cpu().tolist())
+             train_labels.extend(batch_y.int().cpu().tolist())
+
+         avg_loss = total_loss / max(len(train_loader), 1)
+         train_acc = accuracy_score(train_labels, train_preds) if train_labels else 0.0
+
+         # Validate
+         model.eval()
+         val_preds, val_labels = [], []
+         with torch.no_grad():
+             for batch_x, batch_y in val_loader:
+                 batch_x = batch_x.to(device)
+                 batch_y = batch_y.float().to(device)
+                 output = model(batch_x).squeeze(-1)
+                 preds = (output >= 0.5).int()
+                 val_preds.extend(preds.cpu().tolist())
+                 val_labels.extend(batch_y.int().cpu().tolist())
+
+         val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0
+
+         if epoch % 3 == 0 or epoch == 1 or epoch == EPOCHS:
+             print(f"   {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f}")
+
+         # ">=" ensures weights are written at least once, even if val_acc stays flat
+         if val_acc >= best_val_acc:
+             best_val_acc = val_acc
+             torch.save(model.state_dict(), WEIGHTS_PATH)
+
+     # ── Final metrics ────────────────────────────────────────────
+     if val_labels:
+         precision, recall, f1, _ = precision_recall_fscore_support(
+             val_labels, val_preds, average="binary", zero_division=0,
+         )
+         print("\n📊 Final Validation:")
+         print(f"   Accuracy:  {best_val_acc:.4f}")
+         print(f"   Precision: {precision:.4f}")
+         print(f"   Recall:    {recall:.4f}")
+         print(f"   F1 Score:  {f1:.4f}")
+
+     # ── Save replay buffer ───────────────────────────────────────
+     # Keep up to 50 paths per class so paths and labels stay aligned
+     replay_phish = phishing_files[:50]
+     replay_legit = legit_files[:50]
+     replay_paths = [str(p) for p in replay_phish + replay_legit]
+     replay_labels = [1] * len(replay_phish) + [0] * len(replay_legit)
+     REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
+     torch.save({"paths": replay_paths, "labels": replay_labels}, REPLAY_BUFFER_PATH)
+
+     print(f"\n✅ CNN weights saved to: {WEIGHTS_PATH}")
+     print(f"💾 Replay buffer saved: {len(replay_paths)} paths → {REPLAY_BUFFER_PATH}")
+     print("=" * 60)
+
+
+ def _generate_synthetic_screenshots(
+     phishing_dir: Path,
+     legitimate_dir: Path,
+     count: int = 100,
+ ) -> None:
+     """Generate synthetic screenshots for training when real data unavailable."""
+     import random
+     from PIL import Image, ImageDraw
+
+     for label, save_dir, colors in [
+         ("phishing", phishing_dir, [(200, 50, 50), (180, 30, 30), (220, 80, 60)]),
+         ("legitimate", legitimate_dir, [(50, 120, 200), (30, 100, 180), (60, 140, 220)]),
+     ]:
+         save_dir.mkdir(parents=True, exist_ok=True)
+         existing = len(list(save_dir.glob("*.png")))
+         needed = max(0, count - existing)
+
+         for i in range(needed):
+             # Create varied synthetic images
+             w, h = 1280, 800
+             bg = random.choice(colors)
+             img = Image.new("RGB", (w, h), bg)
+             draw = ImageDraw.Draw(img)
+
+             # Add shapes
+             for _ in range(random.randint(5, 15)):
+                 x1 = random.randint(0, w - 100)
+                 y1 = random.randint(0, h - 100)
+                 x2 = x1 + random.randint(50, 300)
+                 y2 = y1 + random.randint(30, 200)
+                 color = tuple(random.randint(0, 255) for _ in range(3))
+                 draw.rectangle([x1, y1, x2, y2], fill=color)
+
+             # Add text-like rectangles
+             for _ in range(random.randint(3, 8)):
+                 x = random.randint(100, w - 400)
+                 y = random.randint(100, h - 100)
+                 draw.rectangle([x, y, x + random.randint(100, 300), y + 20],
+                                fill=(255, 255, 255))
+
+             img.save(save_dir / f"synthetic_{i:04d}.png")
+
+     logger.info(f"Generated synthetic screenshots in {phishing_dir.parent}")
+
+
+ if __name__ == "__main__":
+     main()
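
A hedged inference sketch for the weights this script saves; `PhishCNN`'s constructor signature is assumed from its use above, and the image path is illustrative:

import torch
import torchvision.transforms as T
from PIL import Image

from cnn_model import PhishCNN

model = PhishCNN(pretrained=False)
model.load_state_dict(torch.load("cnn/cnn_weights.pt", map_location="cpu"))
model.eval()

# Same preprocessing as val_transform above
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open("suspect.png").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    p_phish = model(img).squeeze(-1).item()  # sigmoid output in [0, 1]
print(f"P(phishing) = {p_phish:.3f}")
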
train_gnn.py ADDED
@@ -0,0 +1,228 @@
+ # ============================================================
+ # PhishGuard AI - gnn/train_gnn.py
+ # Full GNN training script.
+ #
+ # Downloads PhishTank bz2 + TRANCO zip + Kaggle CSV mirror
+ # Builds training graphs; EPOCHS below defaults to 2 for a
+ # quick run (raise to ~40 for a full pass); saves gnn_weights.pt
+ # 70/15/15 train/val/test split with stratification
+ # Saves replay buffer to gnn_replay_buffer.pt
+ # ============================================================
+
+ from __future__ import annotations
+
+ import sys
+ import random
+ import logging
+ from pathlib import Path
+ from typing import List, Tuple
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s | %(levelname)-7s | %(message)s",
+ )
+ logger = logging.getLogger("phishguard.gnn.train")
+
+ # Paths
+ GNN_DIR = Path(__file__).parent
+ BACKEND_DIR = GNN_DIR.parent
+ WEIGHTS_PATH = GNN_DIR / "gnn_weights.pt"
+ REPLAY_BUFFER_PATH = BACKEND_DIR / "data" / "gnn_replay_buffer.pt"
+
+ # Add backend to path for imports
+ sys.path.insert(0, str(BACKEND_DIR))
+ sys.path.insert(0, str(GNN_DIR))
+
+
+ def main() -> None:
+     print("=" * 60)
+     print("PhishGuard AI — GNN Training")
+     print("=" * 60)
+
+     import torch
+     import torch.nn.functional as F
+     from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+
+     from domain_graph_builder import DomainGraphBuilder
+     from gnn_model import PhishGNN, PhishMLP, PYGEOM_AVAILABLE
+
+     # ── Download data ────────────────────────────────────────────
+     from data_collector import download_phishtank, download_tranco, merge_datasets
+
+     print("\n📥 Downloading datasets...")
+     phish_urls = download_phishtank(max_urls=50)
+     legit_urls = download_tranco(n=50)
+     print(f"   Phishing URLs:   {len(phish_urls)}")
+     print(f"   Legitimate URLs: {len(legit_urls)}")
+
+     train_data, val_data, test_data = merge_datasets(phish_urls, legit_urls)
+
+     # ── Build graphs ─────────────────────────────────────────────
+     builder = DomainGraphBuilder()
+     CHUNK_SIZE = 4  # Group URLs into small graphs
+
+     def build_dataset(data: List[Tuple[str, int]], desc: str) -> list:
+         """Build graph dataset from (url, label) pairs."""
+         dataset = []
+         # Separate by label
+         phish = [url for url, label in data if label == 1]
+         legit = [url for url, label in data if label == 0]
+
+         for urls, label in [(phish, 1), (legit, 0)]:
+             for i in range(0, len(urls), CHUNK_SIZE):
+                 chunk = urls[i : i + CHUNK_SIZE]
+                 if not chunk:
+                     continue
+                 graph = builder.build_graph(chunk)
+                 x = torch.tensor(graph["features"], dtype=torch.float)
+                 edges = graph["edges"]
+                 if edges:
+                     edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
+                 else:
+                     # Self-loops for graphs with no edges
+                     n = x.size(0)
+                     edge_index = torch.arange(n).unsqueeze(0).repeat(2, 1)
+                 dataset.append({
+                     "x": x,
+                     "edge_index": edge_index,
+                     "y": torch.tensor([float(label)]),
+                 })
+
+         random.shuffle(dataset)
+         print(f"   {desc}: {len(dataset)} graphs")
+         return dataset
+
+     print("\n🔨 Building graphs...")
+     train_graphs = build_dataset(train_data, "Train")
+     val_graphs = build_dataset(val_data, "Val")
+     test_graphs = build_dataset(test_data, "Test")
+
+     # ── Create model ─────────────────────────────────────────────
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"\n🤖 Device: {device}")
+
+     model = PhishGNN() if PYGEOM_AVAILABLE else PhishMLP()
+     model = model.to(device)
+     model_type = "GCN" if PYGEOM_AVAILABLE else "MLP"
+     print(f"   Model: Phish{model_type}")
+     print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     # ── Training ─────────────────────────────────────────────────
+     EPOCHS = 2
+     LR = 0.001
+     WEIGHT_DECAY = 1e-4
+
+     optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
+     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+         optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6,
+     )
+     loss_fn = F.binary_cross_entropy
+
+     best_val_acc = 0.0
+     best_epoch = 0
+
+     print(f"\n🏋️ Training for {EPOCHS} epochs...")
+     print(f"   {'Epoch':>5} | {'Loss':>8} | {'Train Acc':>9} | {'Val Acc':>7} | {'LR':>10}")
+     print(f"   {'─' * 5} | {'─' * 8} | {'─' * 9} | {'─' * 7} | {'─' * 10}")
+
+     for epoch in range(1, EPOCHS + 1):
+         # ── Train ────────────────────────────────────────────────
+         model.train()
+         total_loss = 0.0
+         train_preds = []
+         train_labels = []
+
+         random.shuffle(train_graphs)
+         for item in train_graphs:
+             x = item["x"].to(device)
+             ei = item["edge_index"].to(device)
+             y = item["y"].to(device)
+
+             optimizer.zero_grad()
+             out = model(x, ei)
+             loss = loss_fn(out.squeeze(), y.squeeze())
+             loss.backward()
+             optimizer.step()
+
+             total_loss += loss.item()
+             pred = 1 if out.squeeze().item() >= 0.5 else 0
+             train_preds.append(pred)
+             train_labels.append(int(y.item()))
+
+         avg_loss = total_loss / max(len(train_graphs), 1)
+         train_acc = accuracy_score(train_labels, train_preds)
+
+         # ── Validate ─────────────────────────────────────────────
+         model.eval()
+         val_preds = []
+         val_labels = []
+
+         with torch.no_grad():
+             for item in val_graphs:
+                 x = item["x"].to(device)
+                 ei = item["edge_index"].to(device)
+                 y = item["y"].to(device)
+
+                 out = model(x, ei)
+                 pred = 1 if out.squeeze().item() >= 0.5 else 0
+                 val_preds.append(pred)
+                 val_labels.append(int(y.item()))
+
+         val_acc = accuracy_score(val_labels, val_preds) if val_labels else 0.0
+         scheduler.step(avg_loss)
+         current_lr = optimizer.param_groups[0]["lr"]
+
+         # Print progress
+         if epoch % 5 == 0 or epoch == 1 or epoch == EPOCHS:
+             print(f"   {epoch:>5} | {avg_loss:>8.4f} | {train_acc:>9.4f} | {val_acc:>7.4f} | {current_lr:>10.6f}")
+
+         # Save best model (">=" guarantees the weights file exists for the reload below)
+         if val_acc >= best_val_acc:
+             best_val_acc = val_acc
+             best_epoch = epoch
+             torch.save(model.state_dict(), WEIGHTS_PATH)
+
+     print(f"\n   Best val accuracy: {best_val_acc:.4f} at epoch {best_epoch}")
+
+     # ── Test ─────────────────────────────────────────────────────
+     # Reload best weights
+     model.load_state_dict(
+         torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
+     )
+     model.eval()
+
+     test_preds = []
+     test_labels = []
+     with torch.no_grad():
+         for item in test_graphs:
+             x = item["x"].to(device)
+             ei = item["edge_index"].to(device)
+             y = item["y"].to(device)
+
+             out = model(x, ei)
+             pred = 1 if out.squeeze().item() >= 0.5 else 0
+             test_preds.append(pred)
+             test_labels.append(int(y.item()))
+
+     test_acc = accuracy_score(test_labels, test_preds) if test_labels else 0.0
+     precision, recall, f1, _ = precision_recall_fscore_support(
+         test_labels, test_preds, average="binary", zero_division=0,
+     )
+
+     print("\n📊 Test Results:")
+     print(f"   Accuracy:  {test_acc:.4f}")
+     print(f"   Precision: {precision:.4f}")
+     print(f"   Recall:    {recall:.4f}")
+     print(f"   F1 Score:  {f1:.4f}")
+
+     # ── Save replay buffer ───────────────────────────────────────
+     REPLAY_BUFFER_PATH.parent.mkdir(parents=True, exist_ok=True)
+     replay_buffer = train_graphs[:500]  # Keep up to 500 training samples
+     torch.save(replay_buffer, REPLAY_BUFFER_PATH)
+     print(f"\n💾 Replay buffer saved: {len(replay_buffer)} samples → {REPLAY_BUFFER_PATH}")
+
+     print(f"\n✅ GNN weights saved to: {WEIGHTS_PATH}")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
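
`merge_datasets` lives in data_collector.py, which is not shown in this section; a minimal sketch of the 70/15/15 stratified split it is described as performing, under that assumption (function name and seed are illustrative):

import random
from typing import List, Tuple

def stratified_split_70_15_15(
    phish_urls: List[str],
    legit_urls: List[str],
    seed: int = 42,
) -> Tuple[list, list, list]:
    """Split each class 70/15/15 separately so every split keeps the class ratio."""
    rng = random.Random(seed)
    train, val, test = [], [], []
    for urls, label in ((phish_urls, 1), (legit_urls, 0)):
        shuffled = urls[:]
        rng.shuffle(shuffled)
        n = len(shuffled)
        a, b = int(n * 0.70), int(n * 0.85)
        train += [(u, label) for u in shuffled[:a]]
        val += [(u, label) for u in shuffled[a:b]]
        test += [(u, label) for u in shuffled[b:]]
    for split in (train, val, test):
        rng.shuffle(split)
    return train, val, test
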
url_heuristics.py ADDED
@@ -0,0 +1,326 @@
+ # ============================================================
+ # PhishGuard AI - url_heuristics.py
+ # Tier 2: Heuristic Rule Engine — 15 independent signals
+ #
+ # Pure Python regex ONLY — ZERO I/O, ZERO ML, ZERO network
+ # All regex patterns precompiled in __init__ for < 2ms latency
+ # Max raw score: 130 (sum of the 15 signal weights) → normalized to 0-100
+ # Decision: score >= 80 → BLOCK | < 80 → pass to Tier 3
+ # ============================================================
+
+ from __future__ import annotations
+
+ import re
+ import math
+ import time
+ from dataclasses import dataclass, field
+ from typing import List, Tuple
+ from urllib.parse import urlparse
+
+
+ @dataclass
+ class HeuristicResult:
+     """Result from the Tier 2 heuristic scoring engine."""
+     score: int                                        # 0-100 normalized score
+     signals: List[str] = field(default_factory=list)  # human-readable triggered rules
+     raw_score: int = 0                                # pre-normalization total (max 130)
+
+
+ # Sum of all 15 signal weights: 25+20+15+15+10 + 8×5 + 3 + 2 = 130
+ MAX_RAW_SCORE: int = 130
+
+
+ class HeuristicScorer:
+     """
+     Tier 2 Heuristic Rule Engine.
+
+     Scores URLs 0-100 across 15 independent regex/math signals.
+     All regex patterns are precompiled in __init__ (called once at startup).
+     The score() method runs all 15 checks in < 2ms on standard hardware.
+     """
+
+     def __init__(self) -> None:
+         # ── Precompile ALL regex patterns (called once) ──────────────
+         self._re_ip_hostname = re.compile(
+             r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$"
+         )
+
+         self._suspicious_tlds: frozenset[str] = frozenset({
+             ".xyz", ".tk", ".ml", ".ga", ".cf",
+             ".gq", ".pw", ".top", ".click",
+         })
+
+         self._phishing_keywords: Tuple[str, ...] = (
+             "login", "verify", "secure", "update", "account",
+             "banking", "signin", "reset", "confirm", "suspend",
+             "webscr", "cmd", "payment", "alert",
+         )
+         self._re_phishing_keywords = re.compile(
+             r"(?:" + "|".join(re.escape(kw) for kw in self._phishing_keywords) + r")",
+             re.IGNORECASE,
+         )
+
+         self._brand_names: Tuple[str, ...] = (
+             "paypal", "google", "apple", "microsoft", "amazon",
+             "netflix", "facebook", "instagram", "chase",
+             "wellsfargo", "bankofamerica",
+         )
+         self._brand_legitimate_domains: dict[str, frozenset[str]] = {
+             brand: frozenset({
+                 f"{brand}.com", f"www.{brand}.com",
+                 f"{brand}.org", f"{brand}.net",
+             })
+             for brand in self._brand_names
+         }
+         self._re_brands = re.compile(
+             r"(?:" + "|".join(re.escape(b) for b in self._brand_names) + r")",
+             re.IGNORECASE,
+         )
+
+         self._re_non_standard_port = re.compile(r":(\d+)")
+         self._standard_ports: frozenset[int] = frozenset({80, 443, 8080})
+
+         self._re_double_slash = re.compile(r"(?<=.)//")
+
+         self._re_url_encoded = re.compile(r"%[0-9A-Fa-f]{2}")
+
+     # ── Public API ───────────────────────────────────────────────────
+
+     def score(self, url: str) -> HeuristicResult:
+         """
+         Score a raw URL string from 0-100 for phishing probability.
+         Runs all 15 checks. Returns HeuristicResult with signals.
+         """
+         raw_score: int = 0
+         signals: List[str] = []
+
+         # Parse URL once, reuse across all checks
+         try:
+             parsed = urlparse(url if "://" in url else f"http://{url}")
+         except Exception:
+             return HeuristicResult(score=0, signals=["parse_error"], raw_score=0)
+
+         hostname: str = (parsed.hostname or "").lower()
+         path: str = parsed.path or ""
+         query: str = parsed.query or ""
+         url_lower: str = url.lower()
+
+         # Extract domain (without subdomains) for brand check
+         host_parts = hostname.split(".")
+         domain = ".".join(host_parts[-2:]) if len(host_parts) >= 2 else hostname
+
+         # Run all 15 checks
+         checks: List[Tuple[int, str]] = [
+             self._check_ip_hostname(hostname),
+             self._check_suspicious_tld(hostname),
+             self._check_phishing_keywords(url_lower),
+             self._check_brand_spoofing(url_lower, domain),
+             self._check_subdomain_depth(hostname),
+             self._check_url_length(url),
+             self._check_domain_length(hostname),
+             self._check_hyphen_count(hostname),
+             self._check_digit_ratio(hostname),
+             self._check_shannon_entropy(hostname),
+             self._check_non_standard_port(parsed.netloc),
+             self._check_double_slash_redirect(path),
+             self._check_url_encoding(url),
+             self._check_query_length(query),
+             self._check_path_depth(path),
+         ]
+
+         for points, signal in checks:
+             if points > 0:
+                 raw_score += points
+                 signals.append(signal)
+
+         # Normalize: raw_score / 130 * 100, capped at 100
+         normalized = min(round(raw_score / MAX_RAW_SCORE * 100), 100)
+
+         return HeuristicResult(
+             score=normalized,
+             signals=signals,
+             raw_score=raw_score,
+         )
+
+     # ── 15 Individual Signal Checks ──────────────────────────────────
+
+     def _check_ip_hostname(self, hostname: str) -> Tuple[int, str]:
+         """Signal 1: IP address as hostname (25 points)."""
+         if self._re_ip_hostname.match(hostname):
+             return 25, "IP address as hostname"
+         return 0, ""
+
+     def _check_suspicious_tld(self, hostname: str) -> Tuple[int, str]:
+         """Signal 2: Suspicious/cheap TLD (20 points)."""
+         for tld in self._suspicious_tlds:
+             if hostname.endswith(tld):
+                 return 20, f"Suspicious TLD ({tld})"
+         return 0, ""
+
+     def _check_phishing_keywords(self, url_lower: str) -> Tuple[int, str]:
+         """Signal 3: Phishing keywords in URL (15 points)."""
+         matches = self._re_phishing_keywords.findall(url_lower)
+         if matches:
+             unique = set(m.lower() for m in matches)
+             return 15, f"Phishing keywords: {', '.join(sorted(unique))}"
+         return 0, ""
+
+     def _check_brand_spoofing(self, url_lower: str, domain: str) -> Tuple[int, str]:
+         """Signal 4: Brand name in URL but wrong domain (15 points)."""
+         matches = self._re_brands.findall(url_lower)
+         for brand_match in matches:
+             brand = brand_match.lower()
+             legit_domains = self._brand_legitimate_domains.get(brand, frozenset())
+             # Check if the domain IS the legitimate brand domain
+             if domain not in legit_domains and f"www.{domain}" not in legit_domains:
+                 return 15, f"Brand spoofing: '{brand}' on non-brand domain"
+         return 0, ""
+
+     def _check_subdomain_depth(self, hostname: str) -> Tuple[int, str]:
+         """Signal 5: Excessive subdomains >= 3 (10 points)."""
+         parts = hostname.split(".")
+         subdomain_count = max(0, len(parts) - 2)
+         if subdomain_count >= 3:
+             return 10, f"Excessive subdomains ({subdomain_count})"
+         return 0, ""
+
+     def _check_url_length(self, url: str) -> Tuple[int, str]:
+         """Signal 6: URL length > 100 chars (5 points)."""
+         if len(url) > 100:
+             return 5, f"Long URL ({len(url)} chars)"
+         return 0, ""
+
+     def _check_domain_length(self, hostname: str) -> Tuple[int, str]:
+         """Signal 7: Domain length > 30 chars (5 points)."""
+         if len(hostname) > 30:
+             return 5, f"Long domain ({len(hostname)} chars)"
+         return 0, ""
+
+     def _check_hyphen_count(self, hostname: str) -> Tuple[int, str]:
+         """Signal 8: Hyphen count >= 3 in domain (5 points)."""
+         count = hostname.count("-")
+         if count >= 3:
+             return 5, f"Excessive hyphens in domain ({count})"
+         return 0, ""
+
+     def _check_digit_ratio(self, hostname: str) -> Tuple[int, str]:
+         """Signal 9: Digit ratio in domain > 0.3 (5 points)."""
+         if not hostname:
+             return 0, ""
+         digits = sum(1 for c in hostname if c.isdigit())
+         ratio = digits / len(hostname)
+         if ratio > 0.3:
+             return 5, f"High digit ratio in domain ({ratio:.2f})"
+         return 0, ""
+
+     def _check_shannon_entropy(self, hostname: str) -> Tuple[int, str]:
+         """Signal 10: High Shannon entropy > 3.5 (5 points)."""
+         if not hostname:
+             return 0, ""
+         length = len(hostname)
+         freq: dict[str, int] = {}
+         for c in hostname:
+             freq[c] = freq.get(c, 0) + 1
+         entropy = -sum(
+             (count / length) * math.log2(count / length)
+             for count in freq.values()
+             if count > 0
+         )
+         if entropy > 3.5:
+             return 5, f"High entropy domain ({entropy:.2f})"
+         return 0, ""
+
+     def _check_non_standard_port(self, netloc: str) -> Tuple[int, str]:
+         """Signal 11: Non-standard port in URL (5 points)."""
+         match = self._re_non_standard_port.search(netloc)
+         if match:
+             port = int(match.group(1))
+             if port not in self._standard_ports:
+                 return 5, f"Non-standard port (:{port})"
+         return 0, ""
+
+     def _check_double_slash_redirect(self, path: str) -> Tuple[int, str]:
+         """Signal 12: Double slash redirect in path (5 points)."""
+         if self._re_double_slash.search(path):
+             return 5, "Double-slash redirect in path"
+         return 0, ""
+
+     def _check_url_encoding(self, url: str) -> Tuple[int, str]:
+         """Signal 13: URL-encoded characters > 5 (5 points)."""
+         encoded_chars = self._re_url_encoded.findall(url)
+         if len(encoded_chars) > 5:
+             return 5, f"Excessive URL encoding ({len(encoded_chars)} encoded chars)"
+         return 0, ""
+
+     def _check_query_length(self, query: str) -> Tuple[int, str]:
+         """Signal 14: Query string length > 200 (3 points)."""
+         if len(query) > 200:
+             return 3, f"Long query string ({len(query)} chars)"
+         return 0, ""
+
+     def _check_path_depth(self, path: str) -> Tuple[int, str]:
+         """Signal 15: Path depth > 6 levels (2 points)."""
+         segments = [s for s in path.split("/") if s]
+         if len(segments) > 6:
+             return 2, f"Deep path ({len(segments)} levels)"
+         return 0, ""
+
+
+ # ── Legacy compatibility wrapper ─────────────────────────────────────
+ _default_scorer = HeuristicScorer()
+
+
+ def analyze_url(url: str) -> dict:
+     """
+     Legacy-compatible wrapper around HeuristicScorer.
+     Returns dict with 'score', 'flags', 'is_suspicious' for backward compat.
+     """
+     result = _default_scorer.score(url)
+     return {
+         "score": result.score,
+         "flags": result.signals,
+         "is_suspicious": result.score >= 40,
+         "raw_score": result.raw_score,
+     }
+
+
+ # ── Benchmark (run directly to test latency) ─────────────────────────
+ if __name__ == "__main__":
+     test_urls = [
+         "https://www.google.com",
+         "http://192.168.1.1/admin/login",
+         "https://paypal-secure-login.xyz/verify/account?id=12345",
+         "https://a.b.c.d.evil.com/login/secure/update/verify",
+         "https://secure-login-bank-now.tk:9999/path/a/b/c/d/e/f/g.php",
+         "http://xn--pypal-4ve.com/signin?token=%2F%40%3A%20%2F%40%3A%20extra&" + "a" * 200,
+         "https://www.amazon.com/dp/B0something",
+         "https://microsft-l0gin-verfy.ga/account/suspended//evil.com/reset",
+     ]
+
+     scorer = HeuristicScorer()
+     print("=" * 70)
+     print("PhishGuard Tier 2 — Heuristic Benchmark")
+     print("=" * 70)
+
+     times: list[float] = []
+     for url in test_urls:
+         start = time.perf_counter()
+         result = scorer.score(url)
+         elapsed_us = (time.perf_counter() - start) * 1_000_000
+
+         times.append(elapsed_us)
+         action = "BLOCK" if result.score >= 80 else "→ Tier 3"
+         print(f"\n URL: {url[:80]}...")
+         print(f"   Score: {result.score}/100 (raw {result.raw_score}/{MAX_RAW_SCORE}) → {action}")
+         if result.signals:
+             for sig in result.signals:
+                 print(f"   ⚡ {sig}")
+         print(f"   Latency: {elapsed_us:.0f}µs")
+
+     print("\n" + "=" * 70)
+     print(f"   Avg latency: {sum(times)/len(times):.0f}µs")
+     print(f"   Max latency: {max(times):.0f}µs")
+     print("   Target: < 2000µs (2ms)")
+     print(f"   Result: {'✅ PASS' if max(times) < 2000 else '❌ FAIL'}")
+     print("=" * 70)
visual_analyzer.py ADDED
@@ -0,0 +1,512 @@
+ # ============================================================
+ # PhishGuard AI - visual_analyzer.py
+ # Takes a screenshot of a webpage using a headless browser
+ # and analyzes it for visual phishing indicators.
+ #
+ # Screenshot parameters (from architecture doc 2.3):
+ #   Viewport: 1280×800 (standard desktop resolution)
+ #   Timeout: 10s (prevent hanging on slow/malicious pages)
+ #   Wait: domcontentloaded (faster than networkidle)
+ #   Blocked: fonts, media, video (60-70% faster load)
+ #   User-Agent: Chrome 120 string (avoid bot detection)
+ #
+ # Tier 4 is OPTIONAL — controlled by env var ENABLE_VISUAL_TIER.
+ #   Set ENABLE_VISUAL_TIER=1 to enable.
+ #   Unset / set 0 → tier 4 is skipped with "tier4_disabled".
+ #
+ # Render.com: If deploying with Playwright, your render.yaml
+ # build command must install Chromium deps. See render.yaml
+ # comments and the Dockerfile for required apt packages.
+ #
+ # Latency budget: < 200ms for screenshot capture
+ # ============================================================
+
+ from __future__ import annotations
+
+ import os
+ import time
+ import hashlib
+ import logging
+ from urllib.parse import urlparse
+
+ logger = logging.getLogger("phishguard.visual")
+
+ # ── Environment gate ─────────────────────────────────────────────────────────
+ ENABLE_VISUAL_TIER = os.environ.get("ENABLE_VISUAL_TIER", "0").strip() in ("1", "true", "yes")
+
+ if not ENABLE_VISUAL_TIER:
+     print("[PhishGuard] Tier 4 visual analysis DISABLED (set ENABLE_VISUAL_TIER=1 to enable)")
+
+ # ── Playwright availability ──────────────────────────────────────────────────
+ PLAYWRIGHT_AVAILABLE = False
+ if ENABLE_VISUAL_TIER:
+     try:
+         from playwright.async_api import async_playwright
+         PLAYWRIGHT_AVAILABLE = True
+         print("[PhishGuard] Playwright available — screenshot capture enabled")
+     except ImportError:
+         print("[PhishGuard] Playwright not installed — visual analysis will use heuristic-only mode")
+
+ # ── PIL availability ─────────────────────────────────────────────────────────
+ _pil_available = False
+ try:
+     from PIL import Image
+     import io as _io
+     _pil_available = True
+ except ImportError:
+     print("[PhishGuard] Pillow not available — color analysis disabled")
+
+
+ # ── Screenshot cache config ──────────────────────────────────────────────────
+ _CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "screenshots")
+ _CACHE_TTL = 24 * 60 * 60  # 24 hours in seconds
+
+ os.makedirs(_CACHE_DIR, exist_ok=True)
+
+
+ # ── Brand / financial keyword databases ──────────────────────────────────────
+ BRAND_DATABASE = {
+     # brand_keyword → list of legitimate domains
+     "paypal": ["paypal.com"],
+     "apple": ["apple.com", "icloud.com"],
+     "google": ["google.com", "gmail.com", "accounts.google.com"],
+     "amazon": ["amazon.com", "amazon.co.uk", "aws.amazon.com"],
+     "microsoft": ["microsoft.com", "live.com", "outlook.com", "office.com"],
+     "netflix": ["netflix.com"],
+     "facebook": ["facebook.com", "fb.com"],
+     "instagram": ["instagram.com"],
+     "chase": ["chase.com"],
+     "wellsfargo": ["wellsfargo.com"],
+     "bankofamerica": ["bankofamerica.com"],
+     "citibank": ["citibank.com", "citi.com"],
+     "hsbc": ["hsbc.com"],
+     "hdfc": ["hdfcbank.com"],
+     "icici": ["icicibank.com"],
+     "sbi": ["onlinesbi.com", "sbi.co.in"],
+ }
+
+ FINANCIAL_BRANDS = {
+     "paypal", "chase", "wellsfargo", "bankofamerica", "citibank",
+     "hsbc", "hdfc", "icici", "sbi", "bank", "banking",
+ }
+
+
+ def _domain_hash(url: str) -> str:
+     """Generate a stable hash for screenshot caching based on the domain."""
+     try:
+         parsed = urlparse(url if url.startswith("http") else "http://" + url)
+         host = parsed.hostname or url
+         return hashlib.sha256(host.encode()).hexdigest()[:16]
+     except Exception:
+         return hashlib.sha256(url.encode()).hexdigest()[:16]
+
+
+ def _get_root_domain(url: str) -> str:
+     """Extract root domain from URL. E.g. https://login.paypal.com → paypal.com"""
+     try:
+         parsed = urlparse(url if url.startswith("http") else "http://" + url)
+         host = (parsed.hostname or "").lower().replace("www.", "")
+         parts = host.split(".")
+         return ".".join(parts[-2:]) if len(parts) >= 2 else host
+     except Exception:
+         return ""
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # SCREENSHOT CAPTURE (with cache)
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ def _get_cached_screenshot(url: str) -> bytes | None:
+     """
+     Check if a cached screenshot exists for this domain and is < 24 hours old.
+     Returns the screenshot bytes or None.
+     """
+     dhash = _domain_hash(url)
+     cache_path = os.path.join(_CACHE_DIR, f"{dhash}.png")
+
+     if not os.path.exists(cache_path):
+         return None
+
+     # Check age
+     age = time.time() - os.path.getmtime(cache_path)
+     if age >= _CACHE_TTL:
+         # Expired — delete stale cache
+         try:
+             os.remove(cache_path)
+         except OSError:
+             pass
+         return None
+
+     try:
+         with open(cache_path, "rb") as f:
+             data = f.read()
+         logger.info(f"Screenshot cache HIT | url={url} | age={age:.0f}s")
+         return data
+     except Exception:
+         return None
+
+
+ def _save_screenshot_cache(url: str, data: bytes):
+     """Save screenshot bytes to cache as screenshots/<domain_hash>.png."""
+     try:
+         dhash = _domain_hash(url)
+         cache_path = os.path.join(_CACHE_DIR, f"{dhash}.png")
+         with open(cache_path, "wb") as f:
+             f.write(data)
+         logger.info(f"Screenshot cached | url={url} | path={cache_path}")
+     except Exception as e:
+         logger.warning(f"Screenshot cache write failed | error={e}")
+
+
+ async def take_screenshot(url: str) -> bytes | None:
+     """
+     Open the URL in a hidden (headless) browser and take a screenshot.
+     The user never sees this browser window.
+
+     Uses a 24-hour cache: if screenshots/<domain_hash>.png exists and is
+     fresh, returns cached bytes without launching a browser.
+
+     Returns: screenshot as bytes, or None if it fails.
+     """
+     # Gate: tier 4 disabled
+     if not ENABLE_VISUAL_TIER:
+         return None
+
+     # Check cache first
+     cached = _get_cached_screenshot(url)
+     if cached is not None:
+         return cached
+
+     # Playwright not available — can't take a fresh screenshot
+     if not PLAYWRIGHT_AVAILABLE:
+         logger.warning(f"Screenshot skipped (no Playwright) | url={url}")
+         return None
+
+     try:
+         async with async_playwright() as p:
+             browser = await p.chromium.launch(headless=True)
+             context = await browser.new_context(
+                 viewport={"width": 1280, "height": 800},
+                 ignore_https_errors=True,
+                 user_agent=(
+                     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+                     "AppleWebKit/537.36 (KHTML, like Gecko) "
+                     "Chrome/120.0.0.0 Safari/537.36"
+                 )
+             )
+             page = await context.new_page()
+
+             # Block fonts and media to speed up loading (60-70% faster)
+             await page.route(
+                 "**/*.{woff,woff2,ttf,mp4,mp3,wav}",
+                 lambda route: route.abort()
+             )
+
+             await page.goto(url, timeout=10000, wait_until="domcontentloaded")
+
+             # (Page metadata extraction lives in take_screenshot_with_metadata.)
+             screenshot = await page.screenshot(full_page=False)
+             await browser.close()
+
+             # Cache the screenshot for 24 hours
+             if screenshot:
+                 _save_screenshot_cache(url, screenshot)
+
+             return screenshot
+
+     except Exception as e:
+         logger.error(f"Screenshot failed | url={url} | error={e}")
+         return None
+
+
+ async def take_screenshot_with_metadata(url: str) -> dict:
+     """
+     Enhanced screenshot capture that also extracts page metadata
+     (title, login forms) for heuristic visual scoring.
+
+     Returns: {
+         "screenshot": bytes|None,
+         "page_title": str,
+         "has_password_field": bool,
+         "uses_https": bool,
+         "error": str|None
+     }
+     """
+     result = {
+         "screenshot": None,
+         "page_title": "",
+         "has_password_field": False,
+         "uses_https": url.lower().startswith("https"),
+         "error": None,
+     }
+
+     # Gate: tier 4 disabled
+     if not ENABLE_VISUAL_TIER:
+         result["error"] = "tier4_disabled"
+         return result
+
+     # Check screenshot cache (metadata won't be cached, just the image)
+     cached = _get_cached_screenshot(url)
+     if cached is not None:
+         result["screenshot"] = cached
+         # We can't get page metadata from cache, but we have the image
+         return result
+
+     if not PLAYWRIGHT_AVAILABLE:
+         result["error"] = "playwright_not_available"
+         return result
+
+     try:
+         async with async_playwright() as p:
+             browser = await p.chromium.launch(headless=True)
+             context = await browser.new_context(
+                 viewport={"width": 1280, "height": 800},
+                 ignore_https_errors=True,
+                 user_agent=(
+                     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+                     "AppleWebKit/537.36 (KHTML, like Gecko) "
+                     "Chrome/120.0.0.0 Safari/537.36"
+                 )
+             )
+             page = await context.new_page()
+
+             await page.route(
+                 "**/*.{woff,woff2,ttf,mp4,mp3,wav}",
+                 lambda route: route.abort()
+             )
+
+             await page.goto(url, timeout=10000, wait_until="domcontentloaded")
+
+             # Extract metadata
+             result["page_title"] = await page.title() or ""
+             result["has_password_field"] = await page.locator("input[type='password']").count() > 0
+
+             screenshot = await page.screenshot(full_page=False)
+             await browser.close()
+
+             result["screenshot"] = screenshot
+
+             # Cache the screenshot
+             if screenshot:
+                 _save_screenshot_cache(url, screenshot)
+
+     except Exception as e:
+         result["error"] = str(e)
+         logger.error(f"Screenshot+metadata failed | url={url} | error={e}")
+
+     return result
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # VISUAL PHISHING HEURISTICS (no CNN needed)
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ def analyze_visual_heuristic(url: str, page_title: str = "",
+                              has_password_field: bool = False) -> dict:
+     """
+     Heuristic visual phishing scoring WITHOUT needing a trained CNN.
+     Returns heuristic_visual_score from 0.0 to 1.0 based on:
+
+     Signal 1: Page title contains brand names but domain doesn't match
+     Signal 2: Page has a login form (input[type=password])
+     Signal 3: SSL cert missing for pages mentioning financial brands
+     Signal 4: Brand keyword in URL path but not in domain (path spoofing)
+
+     Returns: {
+         heuristic_visual_score: float 0..1,
+         flags: list[str],
+         brand_mismatch: bool,
+         has_login_form: bool,
+         ssl_missing_financial: bool
+     }
+     """
+     score = 0.0
+     flags = []
+     brand_mismatch = False
+     ssl_missing_financial = False
+     root_domain = _get_root_domain(url)
+     url_lower = url.lower()
+     title_lower = (page_title or "").lower()
+     uses_https = url_lower.startswith("https")
+
+     # ── Signal 1: Brand name in page title but domain doesn't match ───────
+     for brand, legit_domains in BRAND_DATABASE.items():
+         if brand in title_lower:
+             if not any(d in root_domain for d in legit_domains):
+                 score += 0.30
+                 flags.append(f"title_brand_mismatch:{brand}")
+                 brand_mismatch = True
+                 break  # One brand mismatch is enough
+
+     # ── Signal 2: Login form detected (input[type=password]) ──────────────
+     if has_password_field:
+         score += 0.15
+         flags.append("has_password_field")
+         # Extra risk if combined with brand mismatch
+         if brand_mismatch:
+             score += 0.15
+             flags.append("login_form_with_brand_mismatch")
+
+     # ── Signal 3: No SSL for financial brand content ──────────────────────
+     mentions_financial = any(
+         fb in title_lower or fb in url_lower
+         for fb in FINANCIAL_BRANDS
+     )
+     if mentions_financial and not uses_https:
+         score += 0.25
+         flags.append("no_ssl_financial_content")
+         ssl_missing_financial = True
+
+     # ── Signal 4: Brand keyword in URL path but not in domain ─────────────
+     try:
+         parsed = urlparse(url)
+         path = (parsed.path or "").lower()
+         for brand, legit_domains in BRAND_DATABASE.items():
+             if brand in path and not any(d in root_domain for d in legit_domains):
+                 score += 0.15
+                 flags.append(f"brand_in_path_not_domain:{brand}")
+                 break
+     except Exception:
+         pass
+
+     return {
+         "heuristic_visual_score": round(min(score, 1.0), 4),
+         "flags": flags,
+         "brand_mismatch": brand_mismatch,
+         "has_login_form": has_password_field,
+         "ssl_missing_financial": ssl_missing_financial,
+     }
+
+
+ def analyze_visual_basic(screenshot_bytes: bytes, url: str) -> dict:
+     """
+     Basic visual analysis using color histograms.
+     Detects if a page uses colors associated with known brands
+     but the URL doesn't match that brand.
+
+     Note: For full CNN analysis, see cnn/cnn_model.py
+     """
+     if not screenshot_bytes:
+         return {"visual_risk": 0.1, "note": "screenshot_failed"}
+
+     if not _pil_available:
+         return {"visual_risk": 0.1, "note": "pil_not_available"}
+
+     try:
+         img = Image.open(_io.BytesIO(screenshot_bytes)).convert("RGB")
+         img_small = img.resize((224, 224))
+
+         # Get average color channels (split once, reuse the bands)
+         r_band, g_band, b_band = img_small.split()
+         r_vals = list(r_band.getdata())
+         g_vals = list(g_band.getdata())
+         b_vals = list(b_band.getdata())
+
+         r_avg = sum(r_vals) / len(r_vals)
+         g_avg = sum(g_vals) / len(g_vals)
+         b_avg = sum(b_vals) / len(b_vals)
+
+         risk = 0.2  # baseline
+         url_lower = url.lower()
+
+         # PayPal brand colors: deep blue
+         if b_avg > r_avg * 1.4 and b_avg > g_avg * 1.3:
+             if "paypal" not in url_lower:
+                 risk += 0.25
+
+         # Microsoft brand colors: orange/blue
+         if r_avg > 180 and b_avg < 100:
+             if "microsoft" not in url_lower and "office" not in url_lower:
+                 risk += 0.20
+
+         # Apple brand: mostly white/grey
+         if r_avg > 220 and g_avg > 220 and b_avg > 220:
+             if "apple" not in url_lower:
+                 risk += 0.10
+
+         return {
+             "visual_risk": round(min(risk, 1.0), 4),
+             "dominant_rgb": [round(r_avg), round(g_avg), round(b_avg)],
+             "note": "basic_color_analysis"
+         }
+
+     except Exception as e:
+         logger.warning(f"Color analysis failed | error={e}")
+         return {"visual_risk": 0.1, "note": "analysis_error"}
+
+
+ # ══════════════════════════════════════════════════════════════════════════════
+ # FULL TIER 4 ANALYSIS (combines CNN + heuristics + color)
+ # ══════════════════════════════════════════════════════════════════════════════
+
+ async def run_tier4_analysis(url: str, page_title: str = "",
+                              page_snippet: str = "") -> dict:
+     """
+     Complete Tier 4 visual analysis pipeline.
+     Called by main.py for borderline cases (0.40 ≤ P₃ < 0.85).
+
+     Graceful fallback chain:
+       1. If ENABLE_VISUAL_TIER is off → tier4_disabled
+       2. If screenshot fails → screenshot_failed (with heuristic fallback)
+       3. If CNN fails → uses heuristic_visual_score only
+
+     Returns: {
+         tier4_score: float|None,
+         tier4_status: str ("ok"|"screenshot_failed"|"tier4_disabled"|...),
+         tier4_reason: str,
+         visual_heuristic: dict,
+         color_analysis: dict,
+         screenshot_cached: bool
+     }
+     """
+     # ── Gate: completely skip if not enabled ───────────────────────────────
+     if not ENABLE_VISUAL_TIER:
+         return {
+             "tier4_score": None,
+             "tier4_status": "tier4_disabled",
+             "tier4_reason": "ENABLE_VISUAL_TIER env var not set",
+         }
+
+     # ── Attempt screenshot with metadata extraction ───────────────────────
+     meta = await take_screenshot_with_metadata(url)
+     screenshot = meta["screenshot"]
+     extracted_title = meta["page_title"] or page_title
+     has_password = meta["has_password_field"]
+     screenshot_error = meta["error"]
+
+     # ── Always run visual heuristics (no screenshot needed) ───────────────
+     heuristic = analyze_visual_heuristic(
+         url,
+         page_title=extracted_title,
+         has_password_field=has_password,
+     )
+
+     # ── Screenshot failed → return heuristic-only result ──────────────────
+     if screenshot is None:
+         reason = screenshot_error or "unknown_screenshot_error"
+         return {
+             "tier4_score": heuristic["heuristic_visual_score"],
+             "tier4_status": "screenshot_failed",
+             "tier4_reason": reason,
+             "visual_heuristic": heuristic,
+             "color_analysis": None,
+             "screenshot_cached": False,
+         }
+
+     # ── Color-based analysis (works without trained CNN) ──────────────────
+     color = analyze_visual_basic(screenshot, url)
+
+     # ── Combine heuristic + color into a single tier4 score ───────────────
+     # Weight: 60% heuristic, 40% color (since CNN isn't trained)
+     combined = (heuristic["heuristic_visual_score"] * 0.60) + (color["visual_risk"] * 0.40)
+
+     return {
+         "tier4_score": round(min(combined, 1.0), 4),
+         "tier4_status": "ok",
+         "tier4_reason": "heuristic_and_color_analysis",
+         "visual_heuristic": heuristic,
+         "color_analysis": color,
+         "screenshot_cached": _get_cached_screenshot(url) is not None,
+     }
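
A minimal usage sketch for the Tier 4 entry point above, run from the backend directory with ENABLE_VISUAL_TIER=1 exported (the URL is illustrative):

import asyncio

from visual_analyzer import run_tier4_analysis

async def demo():
    verdict = await run_tier4_analysis("http://paypal-secure-login.xyz/verify")
    print(verdict["tier4_status"], verdict.get("tier4_score"))

asyncio.run(demo())
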
vocab.txt ADDED
The diff for this file is too large to render. See raw diff