# ============================================================
# PhishGuard AI - bert_analyzer.py
# Tier 3a: BERT NLP Phishing Classifier
#
# Model: ealvaradob/bert-finetuned-phishing (HuggingFace Hub)
# Tokenization: split on [-./=?&_~%@:] to preserve homoglyphs
# Input: "URL: {tokenized_url}. Title: {title}. Content: {snippet}"
# Output: P_bert ∈ [0, 1]
# Supports: load, predict, incremental_update, save/load_local
# ============================================================
from __future__ import annotations

import logging
import math
import os
import re
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple

logger = logging.getLogger("phishguard.bert")

# ── Model state ──────────────────────────────────────────────────────
_classifier = None
_tokenizer = None
_model = None
_use_bert: bool = False
_bert_load_attempted: bool = False
_bert_lock = threading.Lock()

# Check if the transformers library is installed
_transformers_available: bool = False
try:
    import transformers as _tf_module
    _transformers_available = True
    logger.info("transformers library found → BERT will lazy-load on first call")
except ImportError:
    logger.info("transformers not installed → using keyword NLP fallback")

# ── Phishing pattern databases (for keyword fallback) ────────────────
PHISHING_TERMS = [
    "verify your account", "suspended", "click here immediately",
    "unusual activity", "confirm your identity", "limited time",
    "your password has been", "unauthorized access", "act now",
    "secure your account", "login credentials", "reset password immediately",
    "your account will be", "verify your identity", "we noticed suspicious",
]

PHISHING_KEYWORDS = [
    "login", "secure", "verify", "account", "update", "confirm",
    "banking", "paypal", "signin", "password", "suspend", "alert",
    "restore", "unusual", "limited", "expire", "urgent", "immediately",
]

BRAND_NAMES = [
    "paypal", "google", "apple", "microsoft", "amazon", "netflix",
    "facebook", "instagram", "twitter", "linkedin", "chase", "wells",
    "bankofamerica", "citibank", "usps", "fedex", "ebay",
]


class BERTPhishingClassifier:
    """
    BERT-based phishing text classifier.

    Wraps a HuggingFace model with URL-aware tokenization.
    """

    DEFAULT_MODEL = "ealvaradob/bert-finetuned-phishing"
    FALLBACK_MODEL = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

    def __init__(self, model_name: Optional[str] = None) -> None:
        self.model_name: str = model_name or os.environ.get("HF_BERT_REPO") or self.DEFAULT_MODEL
        self._pipeline = None
        self._tokenizer = None
        self._model = None
        self._loaded: bool = False
        self._lock = threading.Lock()
        # Split on URL delimiters; keeps homoglyph tokens (e.g. "paypa1") intact
        self._re_url_split = re.compile(r"[-./=?&_~%@:]+")

    def load_model(self) -> None:
        """Load the BERT model from the HuggingFace Hub, falling back to a smaller model."""
        if self._loaded:
            return
        with self._lock:
            if self._loaded:
                return
            if not _transformers_available:
                logger.warning("transformers not available, BERT disabled")
                return
            try:
                from transformers import (
                    AutoModelForSequenceClassification,
                    AutoTokenizer,
                    pipeline,
                )

                # Try the primary model first, then the smaller fallback
                for model_id in [self.model_name, self.FALLBACK_MODEL]:
                    try:
                        # Load the tokenizer and model once and reuse them in the
                        # pipeline (avoids fetching the weights a second time)
                        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
                        self._model = AutoModelForSequenceClassification.from_pretrained(model_id)
                        self._pipeline = pipeline(
                            "text-classification",
                            model=self._model,
                            tokenizer=self._tokenizer,
                            truncation=True,
                            max_length=512,
                            device=-1,  # CPU
                        )
                        self.model_name = model_id
                        self._loaded = True
                        logger.info(f"BERT model loaded: {model_id}")
                        return
                    except Exception as e:
                        logger.warning(f"Failed to load {model_id}: {e}")
                        continue
                logger.error("All BERT model candidates failed")
            except Exception as e:
                logger.error(f"BERT initialization failed: {e}")

    def tokenize_url(self, url: str) -> str:
        """
        Split a URL on [-./=?&_~%@:] to preserve homoglyphs.

        Example: "paypa1-l0gin.xyz/verify" → "paypa1 l0gin xyz verify"
        """
        text = url.replace("https://", "").replace("http://", "")
        tokens = self._re_url_split.split(text)
        return " ".join(t for t in tokens if t)

    def predict(self, url: str, title: str = "", snippet: str = "") -> float:
        """
        Predict the phishing probability for a URL plus page context.

        Returns P_bert ∈ [0, 1].
        """
        self.load_model()
        if self._loaded and self._pipeline is not None:
            return self._predict_bert(url, title, snippet)
        return self._predict_keyword(url, title, snippet)

    def _predict_bert(self, url: str, title: str, snippet: str) -> float:
        """BERT model prediction path."""
        url_text = self.tokenize_url(url)
        combined = f"URL: {url_text}. Title: {title}. Content: {snippet[:300]}"
        # The pipeline truncates to 512 *tokens*; slicing to 512 characters
        # here would needlessly discard usable context
        result = self._pipeline(combined)[0]
        label = result["label"].upper()
        confidence = result["score"]
        # Map the model's label to a phishing probability
        if any(kw in label for kw in ["SPAM", "PHISH", "MALICIOUS", "LABEL_1", "1"]):
            raw_prob = confidence
        else:
            raw_prob = 1.0 - confidence
        # Boost with keyword signals
        text_lower = combined.lower()
        phrase_hits = sum(1 for p in PHISHING_TERMS if p in text_lower)
        adjusted = min(raw_prob + phrase_hits * 0.05, 1.0)
        return round(adjusted, 4)

    def _predict_keyword(self, url: str, title: str, snippet: str) -> float:
        """Keyword-based fallback when BERT is unavailable."""
        combined = f"{url} {title} {snippet}".lower()
        url_lower = url.lower()
        score = 0.0
        # Keyword hits in the URL (capped at 0.40)
        kw_hits = sum(1 for kw in PHISHING_KEYWORDS if kw in url_lower)
        score += min(kw_hits * 0.08, 0.40)
        # Phrase matches in the content (capped at 0.48)
        phrase_hits = sum(1 for p in PHISHING_TERMS if p in combined)
        score += min(phrase_hits * 0.12, 0.48)
        # Brand spoofing: brand name present but not on its official domain
        for brand in BRAND_NAMES:
            if brand in url_lower and f"{brand}.com" not in url_lower:
                score += 0.20
                break
        # IP address used as the hostname
        if re.match(r"https?://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", url):
            score += 0.20
        # High Shannon entropy in the hostname suggests a generated domain
        try:
            from urllib.parse import urlparse

            host = urlparse(url if "://" in url else f"http://{url}").hostname or ""
            if host and shannon_entropy(host) > 3.5:
                score += 0.10
        except Exception:
            pass
        return round(min(score, 1.0), 4)
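
    # Worked example of the scoring above (illustrative; the URL is a
    # made-up example, not a real site):
    #   url = "http://secure-login-paypal.xyz/verify"
    #   keyword hits: "secure", "login", "paypal", "verify" → +0.32
    #   brand "paypal" present but not on paypal.com        → +0.20
    #   hostname entropy ≈ 4.0 bits (> 3.5)                 → +0.10
    #   P_keyword = 0.62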

    def incremental_update(
        self,
        samples: List[Tuple[str, int]],
        lr: float = 1e-5,
        epochs: int = 1,
        label_smoothing: float = 0.1,
    ) -> Optional[float]:
        """
        Incremental update: unfreeze only the last 2 transformer layers
        plus the classifier head.

        samples: list of (url, label) pairs, where label is 0 or 1.
        Returns the accuracy delta (float), or None if the update failed.
        """
        if not self._loaded or self._model is None or self._tokenizer is None:
            logger.warning("BERT not loaded, cannot incrementally update")
            return None
        if len(samples) < 5:
            logger.warning(f"Too few samples ({len(samples)}) for BERT update")
            return None
        try:
            import torch
            from torch.optim import AdamW
            from torch.utils.data import DataLoader, TensorDataset

            device = torch.device("cpu")
            model = self._model.to(device)
            # Freeze all layers
            for param in model.parameters():
                param.requires_grad = False
            # Unfreeze the last 2 transformer layers + classifier head
            if hasattr(model, "bert"):
                encoder_layers = model.bert.encoder.layer
                for layer in encoder_layers[-2:]:
                    for param in layer.parameters():
                        param.requires_grad = True
            if hasattr(model, "classifier"):
                for param in model.classifier.parameters():
                    param.requires_grad = True
            # Prepare data
            texts = [self.tokenize_url(url) for url, _ in samples]
            labels = [label for _, label in samples]
            encodings = self._tokenizer(
                texts, truncation=True, padding=True, max_length=512,
                return_tensors="pt",
            )
            label_tensor = torch.tensor(labels, dtype=torch.long).to(device)
            dataset = TensorDataset(
                encodings["input_ids"].to(device),
                encodings["attention_mask"].to(device),
                label_tensor,
            )
            batch_size = min(len(samples), 16)
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

            def _accuracy() -> float:
                # Shared by the pre- and post-update measurements
                model.eval()
                correct = 0
                with torch.no_grad():
                    for ids, mask, labs in loader:
                        outputs = model(input_ids=ids, attention_mask=mask)
                        preds = torch.argmax(outputs.logits, dim=1)
                        correct += (preds == labs).sum().item()
                return correct / len(samples)

            # Pre-update accuracy
            pre_acc = _accuracy()
            # Train only the unfrozen parameters
            optimizer = AdamW(
                (p for p in model.parameters() if p.requires_grad),
                lr=lr,
            )
            loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
            model.train()
            for epoch in range(epochs):
                total_loss = 0.0
                for ids, mask, labs in loader:
                    optimizer.zero_grad()
                    outputs = model(input_ids=ids, attention_mask=mask)
                    loss = loss_fn(outputs.logits, labs)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                logger.info(
                    f"BERT incremental epoch {epoch + 1}/{epochs}, "
                    f"loss={total_loss / len(loader):.4f}"
                )
            # Post-update accuracy
            post_acc = _accuracy()
            delta = post_acc - pre_acc
            self._model = model
            logger.info(
                f"BERT incremental update: {pre_acc:.4f} → {post_acc:.4f} (Δ={delta:+.4f})"
            )
            return round(delta, 4)
        except Exception as e:
            logger.error(f"BERT incremental update failed: {e}")
            return None

    def save(self, path: Path) -> None:
        """Save the model and tokenizer to a directory."""
        if self._model is not None and self._tokenizer is not None:
            path = Path(path)
            path.mkdir(parents=True, exist_ok=True)
            self._model.save_pretrained(str(path))
            self._tokenizer.save_pretrained(str(path))
            logger.info(f"BERT model saved to {path}")

    def load_local(self, path: Path) -> bool:
        """Load the model from a local directory."""
        path = Path(path)
        if not path.exists():
            return False
        try:
            from transformers import (
                AutoModelForSequenceClassification,
                AutoTokenizer,
                pipeline,
            )

            self._tokenizer = AutoTokenizer.from_pretrained(str(path))
            self._model = AutoModelForSequenceClassification.from_pretrained(str(path))
            self._pipeline = pipeline(
                "text-classification",
                model=self._model,
                tokenizer=self._tokenizer,
                truncation=True,
                max_length=512,
                device=-1,
            )
            self._loaded = True
            logger.info(f"BERT model loaded from {path}")
            return True
        except Exception as e:
            logger.error(f"BERT local load failed: {e}")
            return False

    def is_loaded(self) -> bool:
        return self._loaded
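
    # Example call pattern (illustrative only; the URLs and returned
    # values below are made up, not real measurements):
    #
    #     clf = BERTPhishingClassifier()
    #     clf.predict("http://paypa1-secure.xyz/login", title="Sign in")   # e.g. 0.91
    #     clf.incremental_update([
    #         ("paypal.com", 0), ("paypa1-secure.xyz/login", 1),
    #         ("google.com", 0), ("g00gle-verify.top", 1),
    #         ("amazon.com", 0),
    #     ])                                                               # e.g. +0.05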


# ── Legacy compatibility ─────────────────────────────────────────────
_default_classifier = BERTPhishingClassifier()


def analyze_text(url: str, page_title: str = "", page_snippet: str = "") -> dict:
    """Legacy wrapper for backward compatibility with main.py."""
    prob = _default_classifier.predict(url, page_title, page_snippet)
    return {
        "bert_phishing_prob": prob,
        "phrase_hits": 0,  # retained for API compatibility; not recomputed here
        "label": "BERT" if _default_classifier.is_loaded() else "KEYWORD_NLP",
        "confidence": prob,
    }


def shannon_entropy(s: str) -> float:
    """Utility: measure the randomness of a string, in bits per character."""
    if not s:
        return 0.0
    # H(s) = -sum(p_c * log2(p_c)) over each distinct character c,
    # where p_c is the relative frequency of c in s
    prob = [s.count(c) / len(s) for c in set(s)]
    return -sum(p * math.log2(p) for p in prob if p > 0)
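

# ── Illustrative smoke test ──────────────────────────────────────────
# A minimal usage sketch, assuming the module is run directly; the URL
# below is a fabricated example, not a known phishing site.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_url = "http://paypa1-secure-l0gin.xyz/account/verify"
    result = analyze_text(demo_url, page_title="Verify your account")
    print(f"P_bert       = {result['bert_phishing_prob']:.4f}")
    print(f"source       = {result['label']}")
    print(f"host entropy = {shannon_entropy('paypa1-secure-l0gin.xyz'):.2f} bits")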