import os, re, json, time, math
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import gradio as gr

# Optional imports for email classifier (loaded lazily).
# Space still runs if these aren't available (pure lexical fallback).
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
except Exception:
    torch = None
    AutoTokenizer = None
    AutoModelForSequenceClassification = None
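# Dependency note: gradio is required; torch/transformers are optional extras
# handled by the try/except above. A minimal requirements.txt sketch (left
# unpinned on purpose; pin whatever versions you actually test with):
#   gradio
#   torch
#   transformers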
# =========================
# Config (env-overridable)
# =========================
EMAIL_CLASSIFIER_ID = os.getenv("EMAIL_CLASSIFIER_ID", "your-username/mini-phish")  # <- swap to your HF repo when ready
EMAIL_BACKBONE_ID = os.getenv("EMAIL_BACKBONE_ID", "microsoft/MiniLM-L6-H384-uncased")
THRESHOLD_TAU = float(os.getenv("THRESHOLD_TAU", "0.40"))
MAX_SEQ_LEN = int(os.getenv("MAX_SEQ_LEN", "320"))
SUBJECT_TOKEN_BUDGET = int(os.getenv("SUBJECT_TOKEN_BUDGET", "64"))
FUSION_EMAIL_W = float(os.getenv("FUSION_EMAIL_W", "0.6"))
FUSION_URL_W = float(os.getenv("FUSION_URL_W", "0.4"))
URL_OVERRIDE_HIGH = float(os.getenv("URL_OVERRIDE_HIGH", "0.85"))
URL_OVERRIDE_KW = float(os.getenv("URL_OVERRIDE_KW", "0.70"))
ALLOWLIST_SAFE_CAP = float(os.getenv("ALLOWLIST_SAFE_CAP", "0.15"))
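# Any value above can be overridden per deployment via environment variables.
# Illustrative shell invocation ("acme/phish-clf" is a placeholder repo id):
#   EMAIL_CLASSIFIER_ID=acme/phish-clf THRESHOLD_TAU=0.50 python app.py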
# =========================
# Simple data classes
# =========================
@dataclass
class UrlResult:
    url: str
    risk: float
    reasons: List[str]
    contrib: Dict[str, float]  # per-reason contribution for transparency

@dataclass
class EmailResult:
    p_email: float                # final probability after boosts
    kw_hits: List[str]
    strong_hits: List[str]        # subset of kw_hits considered strong
    token_counts: Dict[str, int]  # {"subject_tokens":..,"body_tokens":..,"sequence_len":..}
    p_raw: Optional[float]        # raw model probability (before boosts); None in lexical fallback
    path: Optional[str]           # "classifier" | "backbone" | None (lexical)
# =========================
# URL extraction & heuristics (swap with your real URL model when ready)
# =========================
URL_REGEX = r'(?i)\b((?:https?://|www\.)[^\s<>")]+)'

SUSPICIOUS_TLDS = {
    ".xyz", ".top", ".click", ".link", ".ru", ".cn", ".country", ".gq", ".ga", ".ml", ".tk"
}
SHORTENERS = {"bit.ly", "t.co", "tinyurl.com", "goo.gl", "ow.ly", "is.gd", "cutt.ly", "tiny.one", "lnkd.in"}

def extract_urls(text: str) -> List[str]:
    if not text:
        return []
    urls = re.findall(URL_REGEX, text)
    uniq, seen = [], set()
    for u in urls:
        u = u.strip().strip(').,;\'"')
        if u and u not in seen:
            uniq.append(u)
            seen.add(u)
    return uniq
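# Doctest-style sketch of extract_urls: trailing punctuation is stripped and
# duplicates are dropped while first-seen order is preserved.
#   >>> extract_urls("Check https://bit.ly/x and www.promo.xyz/win.")
#   ['https://bit.ly/x', 'www.promo.xyz/win']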
def url_host(url: str) -> str:
    host = re.sub(r"^https?://", "", url, flags=re.I).split("/")[0].lower()
    return host
def score_url_heuristic(url: str) -> UrlResult:
    """
    Heuristic scoring with a transparent per-reason contribution map.
    This keeps the POC explainable and makes the Forensics panel richer.
    """
    host = url_host(url)
    score = 0.0
    reasons = []
    contrib = {}

    def add(amount: float, tag: str):
        nonlocal score
        score += amount
        reasons.append(tag)
        contrib[tag] = round(contrib.get(tag, 0.0) + amount, 3)

    base = 0.05
    add(base, "base")
    if len(url) > 140:
        add(0.15, "very_long_url")
    if "@" in url or "%" in url:
        add(0.20, "special_chars")
    if any(host.endswith(t) for t in SUSPICIOUS_TLDS):
        add(0.35, "suspicious_tld")
    if any(s in host for s in SHORTENERS):
        add(0.50, "shortener")
    if host.count(".") >= 3:
        add(0.20, "deep_subdomain")
    if len(re.findall(r"[A-Z]", url)) > 16:
        add(0.10, "mixed_case")
    score = min(score, 1.0)
    return UrlResult(url=url, risk=score, reasons=reasons, contrib=contrib)

def score_urls(urls: List[str]) -> List[UrlResult]:
    return [score_url_heuristic(u) for u in urls]
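# Worked example of the additive scoring (numbers follow the add() calls above):
#   score_url_heuristic("https://bit.ly/abc")
#   -> risk = base 0.05 + shortener 0.50 = 0.55
#      reasons = ["base", "shortener"], contrib = {"base": 0.05, "shortener": 0.5}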
# =========================
# Email classifier with fallback
# =========================
_tokenizer = None
_model = None
_model_loaded_from = None  # "classifier", "backbone", or None
_model_load_ms = None
_model_quantized = False

# Strong vs normal cues (lowercase)
STRONG_CUES = [
    "otp", "one-time password", "one time password", "cvv", "pin", "pan",
    "password", "bank details", "netbanking", "debit card", "credit card",
    "lottery", "jackpot", "prize", "reward", "winner", "you have won",
    "send otp", "share otp", "confirm otp", "verify otp",
    "account restricted", "reactivate account", "unlock your account"
]
NORMAL_CUES = [
    "verify your account", "update your password", "immediately",
    "within 24 hours", "suspended", "unusual activity", "confirm",
    "login", "click", "invoice", "payment", "security alert",
    "urgent", "limited time"
]
LEXICAL_CUES = sorted(set(STRONG_CUES + NORMAL_CUES))
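# Caveat: cue matching in score_email() is plain substring search over the
# lowercased text, so short cues can false-positive inside longer words
# (e.g. "footprint" contains "otp"). Acceptable noise for a POC; switch to
# word-boundary regexes if this becomes a problem.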
def load_email_model() -> Tuple[object, object, str]:
    """Try to load EMAIL_CLASSIFIER_ID; on failure, fall back to the backbone with a fresh 2-class head.
    Apply dynamic int8 quantization on CPU if available."""
    global _tokenizer, _model, _model_loaded_from, _model_load_ms, _model_quantized
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model, _model_loaded_from
    start = time.perf_counter()
    if AutoTokenizer is None or AutoModelForSequenceClassification is None or torch is None:
        _model_loaded_from = None
        _model_load_ms = round((time.perf_counter() - start) * 1000, 2)
        return None, None, _model_loaded_from  # environment without torch/transformers
    # Preferred classifier
    try:
        _tokenizer = AutoTokenizer.from_pretrained(EMAIL_CLASSIFIER_ID)
        _model = AutoModelForSequenceClassification.from_pretrained(EMAIL_CLASSIFIER_ID)
        _model_loaded_from = "classifier"
    except Exception:
        # Fallback: backbone + fresh 2-class head
        try:
            _tokenizer = AutoTokenizer.from_pretrained(EMAIL_BACKBONE_ID)
            _model = AutoModelForSequenceClassification.from_pretrained(
                EMAIL_BACKBONE_ID, num_labels=2, problem_type="single_label_classification"
            )
            _model_loaded_from = "backbone"
        except Exception:
            _tokenizer, _model, _model_loaded_from = None, None, None
            _model_load_ms = round((time.perf_counter() - start) * 1000, 2)
            return None, None, _model_loaded_from
    # Dynamic quantization (CPU)
    _model_quantized = False
    try:
        _model.eval()
        _model.to("cpu")
        if hasattr(torch, "quantization"):
            from torch.quantization import quantize_dynamic
            _model = quantize_dynamic(_model, {torch.nn.Linear}, dtype=torch.qint8)  # type: ignore
            _model_quantized = True
    except Exception:
        pass
    _model_load_ms = round((time.perf_counter() - start) * 1000, 2)
    return _tokenizer, _model, _model_loaded_from
def _truncate_for_budget(tokens_subject: List[int], tokens_body: List[int], max_len: int, subj_budget: int):
    subj = tokens_subject[:subj_budget]
    remain = max(0, max_len - len(subj))
    body = tokens_body[:remain]
    return subj + body
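# Budget math sketch: with the defaults, score_email() passes max_len = 318
# (MAX_SEQ_LEN 320 minus CLS/SEP) and subj_budget = 64. For an 80-token subject
# and a 400-token body: subject keeps 64, body keeps 318 - 64 = 254, and the
# final sequence is 64 + 254 + 2 special tokens = 320.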
def score_email(subject: str, body: str) -> Tuple[EmailResult, Dict]:
    """Return EmailResult + debug dict with probability, hits, boosts, timings, token counts, and model info."""
    dbg = {"path": None, "p_raw": None, "boost_from_strong": 0.0, "boost_from_normal": 0.0,
           "timing_ms": {}, "token_counts": {}, "model_info": {}}
    t0 = time.perf_counter()
    text = (subject or "") + "\n" + (body or "")
    low = text.lower()
    strong_hits = [c for c in STRONG_CUES if c in low]
    normal_hits = [c for c in NORMAL_CUES if c in low]
    all_hits = sorted(set(strong_hits + normal_hits))
    tok, mdl, path = load_email_model()
    dbg["path"] = path
    dbg["model_info"] = {
        "loaded_from": path,
        "classifier_id": EMAIL_CLASSIFIER_ID,
        "backbone_id": EMAIL_BACKBONE_ID,
        "quantized": _model_quantized,
        "model_load_ms": _model_load_ms
    }
    if tok is None or mdl is None:
        # Pure lexical fallback (no model available):
        base = 0.10
        p_email = base + 0.18 * len(strong_hits) + 0.07 * len(normal_hits)
        p_email = float(max(0.01, min(0.99, p_email)))
        dbg["p_raw"] = None
        dbg["boost_from_strong"] = 0.18 * len(strong_hits)
        dbg["boost_from_normal"] = 0.07 * len(normal_hits)
        dbg["timing_ms"]["email_infer"] = round((time.perf_counter() - t0) * 1000, 2)
        dbg["token_counts"] = {"subject_tokens": 0, "body_tokens": 0, "sequence_len": 0}
        return EmailResult(
            p_email=p_email, kw_hits=all_hits, strong_hits=strong_hits,
            token_counts=dbg["token_counts"], p_raw=None, path=path
        ), dbg
    # Model path (MiniLM or your classifier)
    enc_t0 = time.perf_counter()
    encoded_subj = tok.encode(subject or "", add_special_tokens=False)
    encoded_body = tok.encode(body or "", add_special_tokens=False)
    input_ids = _truncate_for_budget(encoded_subj, encoded_body, MAX_SEQ_LEN - 2, SUBJECT_TOKEN_BUDGET)
    input_ids = [tok.cls_token_id] + input_ids + [tok.sep_token_id]
    attn_mask = [1] * len(input_ids)
    ids = torch.tensor([input_ids], dtype=torch.long)
    mask = torch.tensor([attn_mask], dtype=torch.long)
    with torch.no_grad():
        out = mdl(input_ids=ids, attention_mask=mask)
    if hasattr(out, "logits"):
        logits = out.logits[0].detach().cpu().numpy().tolist()
        m = max(logits)  # subtract the max before exp for numerical stability
        exps = [math.exp(x - m) for x in logits]
        p_raw = float(exps[1] / (exps[0] + exps[1]))  # assumes label 1 = phishing
    else:
        p_raw = 0.5
    # Nudge with cues: stronger boost for strong hits
    boost_s = 0.10 * len(strong_hits)
    boost_n = 0.03 * len(normal_hits)
    p_email = float(max(0.01, min(0.99, p_raw + boost_s + boost_n)))
    dbg["p_raw"] = round(p_raw, 3)
    dbg["boost_from_strong"] = round(boost_s, 3)
    dbg["boost_from_normal"] = round(boost_n, 3)
    dbg["timing_ms"]["email_infer"] = round((time.perf_counter() - enc_t0) * 1000, 2)
    dbg["token_counts"] = {
        "subject_tokens": len(encoded_subj),
        "body_tokens": len(encoded_body),
        "sequence_len": len(input_ids)
    }
    return EmailResult(
        p_email=p_email, kw_hits=all_hits, strong_hits=strong_hits,
        token_counts=dbg["token_counts"], p_raw=p_raw, path=path
    ), dbg
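# Two worked probability paths (numbers plugged into the formulas above):
#   lexical fallback: 2 strong + 1 normal hit -> 0.10 + 2*0.18 + 1*0.07 = 0.53
#   model path:       p_raw 0.62, 1 strong + 2 normal -> 0.62 + 0.10 + 0.06 = 0.78
# Both paths clamp the result to [0.01, 0.99] before returning it as p_email.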
# =========================
# Fusion
# =========================
def fuse(email_res: EmailResult, url_results: List[UrlResult], allowlist_domains: List[str]) -> Tuple[Dict, Dict]:
    """Return fused decision dict + debug dict explaining the math & overrides."""
    fdbg = {
        "weights": {"email": FUSION_EMAIL_W, "url": FUSION_URL_W},
        "threshold_tau": THRESHOLD_TAU,
        "overrides": {"url_high": URL_OVERRIDE_HIGH, "url_kw": URL_OVERRIDE_KW, "allowlist_safe_cap": ALLOWLIST_SAFE_CAP},
        "applied_overrides": [],
    }
    r_url_max = max([u.risk for u in url_results], default=0.0)
    no_urls = (len(url_results) == 0)
    # Allowlist check
    allowlist_hit = False
    matched_allow = None
    for u in url_results:
        h = url_host(u.url)
        for d in [d.strip().lower() for d in allowlist_domains if d.strip()]:
            # Match on a dot boundary (or exact host) so lookalikes such as
            # "evil-microsoft.com" don't inherit the "microsoft.com" allowlist cap.
            if h == d or h.endswith("." + d):
                allowlist_hit = True
                matched_allow = d
                break
        if allowlist_hit:
            break
    # Base fusion
    r_before = FUSION_EMAIL_W * email_res.p_email + FUSION_URL_W * r_url_max
    # URL-driven overrides
    kw_flag = 1 if email_res.kw_hits else 0
    r_after = r_before
    if r_url_max >= URL_OVERRIDE_HIGH:
        r_after = max(r_after, 0.90)
        fdbg["applied_overrides"].append("URL_OVERRIDE_HIGH")
    elif kw_flag and r_url_max >= URL_OVERRIDE_KW:
        r_after = max(r_after, 0.90)
        fdbg["applied_overrides"].append("URL_OVERRIDE_KW")
    # Email-only strong-cue override
    if no_urls and len(email_res.strong_hits) > 0:
        r_after = max(r_after, 0.85)
        fdbg["applied_overrides"].append("EMAIL_ONLY_STRONG_CUES")
    # Allowlist cap
    if allowlist_hit:
        r_after = min(r_after, ALLOWLIST_SAFE_CAP)
        fdbg["applied_overrides"].append(f"ALLOWLIST({matched_allow})")
    verdict = "UNSAFE" if r_after >= THRESHOLD_TAU else "SAFE"
    fused = {
        "P_email": round(email_res.p_email, 3),
        "P_email_raw": round(email_res.p_raw, 3) if email_res.p_raw is not None else None,
        "R_url_max": round(r_url_max, 3),
        "R_total": round(r_after, 3),
        "R_total_before_overrides": round(r_before, 3),
        "kw_hits": email_res.kw_hits,
        "strong_hits": email_res.strong_hits,
        "token_counts": email_res.token_counts,
        "no_urls": no_urls,
        "allowlist_hit": allowlist_hit,
        "verdict": verdict
    }
    fdbg.update({
        "components": {"P_email": fused["P_email"], "R_url_max": fused["R_url_max"]},
        "no_urls": no_urls,
        "allowlist_hit": allowlist_hit,
        "matched_allow": matched_allow
    })
    return fused, fdbg
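# Fusion sketch with the default weights and tau = 0.40:
#   P_email = 0.80, R_url_max = 0.55
#   R_total = 0.6 * 0.80 + 0.4 * 0.55 = 0.70 >= tau -> UNSAFE
# No override fires here: 0.55 is below URL_OVERRIDE_HIGH (0.85) and
# URL_OVERRIDE_KW (0.70), and URLs are present, so the email-only path is off.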
# =========================
# Gradio UI
# =========================
with gr.Blocks(title="PhishingMail-Lab") as demo:
    gr.Markdown("# 🧪 PhishingMail-Lab\n**POC** — Hybrid (email + URL) with explainable forensics.")
    with gr.Row():
        with gr.Column(scale=3):
            subject = gr.Textbox(label="Subject", placeholder="Subject: Important account update")
            body = gr.Textbox(label="Email Body (paste text or HTML)", lines=12, placeholder="Paste the email content here...")
            with gr.Row():
                allowlist = gr.Textbox(label="Allowlist domains (comma-separated)", placeholder="microsoft.com, amazon.in")
                tau = gr.Slider(0, 1, value=THRESHOLD_TAU, step=0.01, label="Decision Threshold τ")
            analyze_btn = gr.Button("Analyze", variant="primary")
            verdict = gr.Label(label="Verdict")
            # Banner under verdict
            context_banner = gr.Markdown(visible=False)
            fusion_json = gr.JSON(label="Fusion & Flags")
            url_table = gr.Dataframe(headers=["URL", "Risk", "Reasons"], label="Per-URL risk (heuristics demo)", interactive=False)
        # Forensics column
        with gr.Column(scale=2):
            gr.Markdown("### 🔎 Forensics")
            forensics_json = gr.JSON(label="Forensics (structured log)")
            forensics_md = gr.Markdown(label="Forensics (human-readable)")
    def run(subject_text, body_text, allowlist_text, tau_val):
        # Timers for forensics
        t_all = time.perf_counter()
        # Update threshold (module-global; acceptable for a single-user demo Space)
        global THRESHOLD_TAU
        THRESHOLD_TAU = float(tau_val)
        # URL pipeline
        t0 = time.perf_counter()
        raw_text = (subject_text or "") + "\n" + (body_text or "")
        urls = extract_urls(raw_text)  # already unique & ordered
        t1 = time.perf_counter()
        url_results = score_urls(urls)
        t2 = time.perf_counter()
        # Email pipeline
        email_res, email_dbg = score_email(subject_text or "", body_text or "")
        # Fusion
        allow_domains = [d.strip().lower() for d in (allowlist_text or "").split(",") if d.strip()]
        fused, fuse_dbg = fuse(email_res, url_results, allow_domains)
        # Build banner text/visibility
        banners = []
        if fused.get("no_urls"):
            banners.append("⚠️ **No URLs found** — decision based **only on email body**.")
        if fused.get("allowlist_hit"):
            banners.append("🛈 **Allowlist active** — risk **capped** for trusted domain.")
        banner_text = "<br>".join(banners) if banners else ""
        banner_visible = bool(banners)
        # Forensics JSON (deeper detail)
        per_url = [{
            "url": u.url,
            "risk": round(u.risk, 3),
            "reasons": u.reasons,
            "contrib": u.contrib
        } for u in url_results]
        fx = {
            "config": {
                "weights": {"email": FUSION_EMAIL_W, "url": FUSION_URL_W},
                "threshold_tau": THRESHOLD_TAU,
                "overrides": {
                    "url_high": URL_OVERRIDE_HIGH,
                    "url_kw": URL_OVERRIDE_KW,
                    "allowlist_safe_cap": ALLOWLIST_SAFE_CAP
                },
                "model_ids": {"classifier": EMAIL_CLASSIFIER_ID, "backbone": EMAIL_BACKBONE_ID}
            },
            "input_summary": {
                "chars_subject": len(subject_text or ""),
                "chars_body": len(body_text or ""),
                "num_urls": len(urls),
                "allowlist_domains": allow_domains
            },
            "email": {
                "path": email_dbg["path"] or "lexical-fallback",
                "p_email_final": fused["P_email"],
                "p_email_raw": email_dbg["p_raw"],
                "boost_from_strong": email_dbg["boost_from_strong"],
                "boost_from_normal": email_dbg["boost_from_normal"],
                "token_counts": email_dbg["token_counts"],
                "kw_hits": email_res.kw_hits,
                "strong_hits": email_res.strong_hits,
                "model_info": email_dbg["model_info"]
            },
            "urls": per_url,
            "fusion": {
                "equation": f"R_total = {FUSION_EMAIL_W} * P_email + {FUSION_URL_W} * R_url_max",
                "values": {
                    "P_email": fused["P_email"],
                    "R_url_max": fused["R_url_max"],
                    "R_total_before_overrides": fused["R_total_before_overrides"],
                    "R_total_final": fused["R_total"],
                    "overrides_applied": fuse_dbg["applied_overrides"]
                },
                "decision": {
                    "threshold_tau": THRESHOLD_TAU,
                    "verdict": fused["verdict"]
                },
                "flags": {
                    "no_urls": fused["no_urls"],
                    "allowlist_hit": fused["allowlist_hit"]
                }
            },
            "timings_ms": {
                "model_load": email_dbg["model_info"]["model_load_ms"],
                "url_extract": round((t1 - t0) * 1000, 2),
                "url_score": round((t2 - t1) * 1000, 2),
                "email_infer": email_dbg["timing_ms"].get("email_infer"),
                "total": round((time.perf_counter() - t_all) * 1000, 2)
            }
        }
        # Forensics Markdown (human-readable, denser detail)
        lines = []
        lines.append(f"**Verdict:** `{fused['verdict']}` | **R_total:** `{fused['R_total']}` (before: `{fused['R_total_before_overrides']}`) | **τ:** `{THRESHOLD_TAU}`")
        lines.append(f"**Fusion:** R = {FUSION_EMAIL_W}×P_email + {FUSION_URL_W}×R_url_max → {FUSION_EMAIL_W}×{fused['P_email']} + {FUSION_URL_W}×{fused['R_url_max']}")
        if fuse_dbg["applied_overrides"]:
            lines.append(f"**Overrides:** {', '.join(fuse_dbg['applied_overrides'])}")
        else:
            lines.append("**Overrides:** (none)")
        if fused["no_urls"]:
            lines.append("• No URLs found → email-only decision path.")
        if fused["allowlist_hit"]:
            lines.append("• Allowlist matched → risk capped.")
        lines.append("")
        lines.append(f"**Email path:** `{email_dbg['path'] or 'lexical-fallback'}` | p_raw={email_dbg['p_raw']} | +strong={email_dbg['boost_from_strong']} | +normal={email_dbg['boost_from_normal']}")
        tc = email_dbg["token_counts"]
        lines.append(f"• Tokens: subject={tc.get('subject_tokens', 0)}, body={tc.get('body_tokens', 0)}, sequence_len={tc.get('sequence_len', 0)} (max={MAX_SEQ_LEN}) | subject_budget={SUBJECT_TOKEN_BUDGET}")
        if email_res.strong_hits:
            lines.append(f"• Strong cues: {', '.join(email_res.strong_hits)}")
        if email_res.kw_hits:
            lines.append(f"• All cues: {', '.join(email_res.kw_hits)}")
        lines.append("")
        if per_url:
            lines.append("**URLs & contributions:**")
            for u in per_url:
                contrib_str = ", ".join([f"{k}:{v}" for k, v in u["contrib"].items()])
                lines.append(f"• {u['url']} → risk={u['risk']} | reasons=({', '.join(u['reasons']) or 'none'}) | contrib=({contrib_str or 'n/a'})")
        else:
            lines.append("**URLs:** (none)")
        lines.append("")
        lines.append(f"**Model info:** loaded_from={email_dbg['model_info']['loaded_from']}, quantized={email_dbg['model_info']['quantized']}, load_ms={email_dbg['model_info']['model_load_ms']}")
        lines.append("")
        lines.append("**Timings (ms):** " + json.dumps(fx["timings_ms"]))
        forensic_markdown = "\n".join(lines)
        rows = [[u.url, round(u.risk, 3), ", ".join(u.reasons)] for u in url_results]
        return (
            fused["verdict"],
            gr.update(value=banner_text, visible=banner_visible),
            fused,
            rows,
            fx,
            forensic_markdown
        )

    analyze_btn.click(
        run,
        [subject, body, allowlist, tau],
        [verdict, context_banner, fusion_json, url_table, forensics_json, forensics_md]
    )

if __name__ == "__main__":
    demo.launch()