Spaces:
Running
Running
# src/ensemble_router.py
# SupportMind — Ensemble Confidence-Gated Router
# Combines DistilBERT (MC Dropout) + TF-IDF Logistic Regression
# for best-in-class accuracy on ticket routing.
#
# Strategy: weighted soft-voting on probability distributions
#   final_probs = w_bert * bert_probs + w_sklearn * sklearn_probs
#
# Why this beats either model alone:
#   - DistilBERT: captures semantic meaning, handles paraphrases
#   - TF-IDF+LR : captures keyword/n-gram signals, very confident on clear cases
#   - Ensemble  : DistilBERT corrects LR on ambiguous tickets,
#                 LR corrects BERT on keyword-heavy ones
| import os | |
| import gc | |
| import pickle | |
| import logging | |
| import numpy as np | |
| from typing import Dict, Optional | |
| logger = logging.getLogger(__name__) | |
# ── Category map ─────────────────────────────────────────────────────────────
# Class index -> queue name. Assumed to match the label order used at training
# time for both the transformer head and the sklearn pipeline — TODO confirm
# against the training scripts.
CATEGORY_MAP = {
    0: 'billing',
    1: 'technical_support',
    2: 'account_management',
    3: 'feature_request',
    4: 'compliance_legal',
    5: 'onboarding',
    6: 'general_inquiry',
    7: 'churn_risk',
}
# Inverse lookup: queue name -> class index.
CATEGORY_REVERSE = {v: k for k, v in CATEGORY_MAP.items()}

# ── Routing thresholds ───────────────────────────────────────────────────────
# NOTE(review): ROUTE_THRESHOLD / CLARIFY_THRESHOLD are not referenced inside
# this module (route() hard-codes per-category thresholds); presumably kept for
# external consumers — verify before removing.
ROUTE_THRESHOLD = 0.82      # ensemble conf >= this -> auto-route
CLARIFY_THRESHOLD = 0.58    # ensemble conf >= this -> ask 1 question
ENTROPY_MAX = 0.32          # ensemble entropy <= this -> low ambiguity
MC_PASSES = 10              # MC Dropout stochastic passes (sequential for memory)

# ── Ensemble weights ─────────────────────────────────────────────────────────
# Soft-voting weights (sum to 1.0). The transformer gets the larger weight
# because it generalises better to unseen phrasing. These are tunable —
# increase SKLEARN_W if the TF-IDF/LR baseline is more accurate on your data.
BERT_W = 0.75
SKLEARN_W = 0.25
class EnsembleRouter:
    """
    Ensemble Confidence-Gated Router.

    Combines:
      1. A fine-tuned transformer (DistilBERT / DeBERTa) with MC Dropout
         for uncertainty estimation.
      2. A TF-IDF + Calibrated Logistic Regression baseline.

    Final probabilities are a weighted soft-vote:
        blended = BERT_W * bert_probs + SKLEARN_W * sklearn_probs

    Falls back to sklearn-only when transformer weights are absent.
    Drop-in replacement for ConfidenceGatedRouter — same .route() interface.
    """

    def __init__(self, model_dir: Optional[str] = None, device: str = 'cpu'):
        """
        Args:
            model_dir: Directory containing the transformer checkpoint.
                Defaults to models/deberta_ultimate when present, otherwise
                models/ticket_classifier.
            device: torch device string forwarded to ConfidenceGatedRouter.
        """
        base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        ultimate_path = os.path.join(base, 'models', 'deberta_ultimate')
        standard_path = os.path.join(base, 'models', 'ticket_classifier')
        if model_dir is None:
            # Prefer the stronger DeBERTa checkpoint when it has been trained.
            if os.path.exists(os.path.join(ultimate_path, 'config.json')):
                self.model_dir = ultimate_path
            else:
                self.model_dir = standard_path
        else:
            self.model_dir = model_dir

        self._bert_router = None
        self._sklearn_pipe = None
        self._bert_available = False

        # IMPORTANT: Load BERT first and do a warmup pass.
        # On Windows, unpickling sklearn before PyTorch's first forward pass
        # causes a segfault in torch.distributed/optree DLLs.
        self._load_bert(device)
        if self._bert_available:
            self._warmup_bert()
        self._load_sklearn()

        # Historical memory layer is optional; routing works without it.
        try:
            from historical_memory import HistoricalMemoryLayer
            self._memory_layer = HistoricalMemoryLayer()
        except Exception as e:
            logger.warning(f"[EnsembleRouter] Could not load Historical Memory Layer: {e}")
            self._memory_layer = None

        logger.info(
            f"[EnsembleRouter] BERT={'ON' if self._bert_available else 'OFF (fallback)'} | "
            f"sklearn=ON | weights=({BERT_W}/{SKLEARN_W}) | memory={'ON' if getattr(self, '_memory_layer', None) and self._memory_layer.is_ready else 'OFF'}"
        )

    def _warmup_bert(self):
        """Perform a warmup forward pass to initialize PyTorch/CUDA state."""
        try:
            self._bert_router.mc_predict("warmup", n_passes=1)
            logger.info("[EnsembleRouter] BERT warmup complete.")
        except Exception as e:
            logger.warning(f"[EnsembleRouter] BERT warmup failed: {e}")

    # ── Model loaders ────────────────────────────────────────────────────────

    def _load_sklearn(self):
        """Unpickle the TF-IDF + LR pipeline. Raises FileNotFoundError if absent."""
        # Check model_dir first, then fall back to ticket_classifier.
        pkl = os.path.join(self.model_dir, 'sklearn_router.pkl')
        if not os.path.exists(pkl):
            base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            pkl = os.path.join(base, 'models', 'ticket_classifier', 'sklearn_router.pkl')
        if not os.path.exists(pkl):
            raise FileNotFoundError(
                "sklearn_router.pkl not found. "
                "Run: python train_baseline.py"
            )
        # NOTE: pickle.load is only acceptable because this artifact is produced
        # by our own training script — never point model_dir at untrusted files.
        with open(pkl, 'rb') as f:
            self._sklearn_pipe = pickle.load(f)
        logger.info(f"[EnsembleRouter] sklearn pipeline loaded from {pkl}.")

    def _load_bert(self, device: str):
        """Load the fine-tuned transformer. Skips gracefully if weights not saved yet."""
        import json
        model_bin = os.path.join(self.model_dir, 'pytorch_model.bin')
        model_safe = os.path.join(self.model_dir, 'model.safetensors')
        config = os.path.join(self.model_dir, 'config.json')
        bert_ready = os.path.exists(config) and (
            os.path.exists(model_bin) or os.path.exists(model_safe)
        )
        if not bert_ready:
            logger.warning(
                "[EnsembleRouter] DistilBERT weights not found → running sklearn-only."
            )
            return
        # A stale baseline stub config.json may exist before the first real
        # transformer training run; skip BERT in that case.
        try:
            with open(config) as f:
                cfg = json.load(f)
            if cfg.get('model_type') == 'baseline_sklearn':
                logger.warning("[EnsembleRouter] config.json is baseline stub → skipping BERT.")
                return
        except Exception:
            pass
        try:
            from confidence_router import ConfidenceGatedRouter
            self._bert_router = ConfidenceGatedRouter(self.model_dir, device=device)
            self._bert_available = True
            gc.collect()
            logger.info(f"[EnsembleRouter] {self._bert_router.model.config.model_type.upper()} loaded successfully.")
        except Exception as e:  # OSError is already an Exception subclass
            logger.error(f"[EnsembleRouter] BERT load failed (likely memory constraint): {e}")
            # Ensure we don't leave a half-initialized router.
            self._bert_router = None
            self._bert_available = False
            gc.collect()

    # ── Prediction ───────────────────────────────────────────────────────────

    def _sklearn_probs(self, text: str) -> np.ndarray:
        """Return calibrated probability distribution from the sklearn pipeline."""
        return self._sklearn_pipe.predict_proba([text])[0]  # shape [n_classes]

    def _bert_probs(self, text: str, n_passes: Optional[int] = None) -> np.ndarray:
        """Return the MC-Dropout mean probability distribution from the transformer."""
        passes = MC_PASSES if n_passes is None else n_passes
        _, _, _, mean_p, _ = self._bert_router.mc_predict(text, n_passes=passes)
        return mean_p  # shape [n_classes]

    def _blend(self, text: str, n_passes: Optional[int] = None):
        """
        Compute the blended probability distribution.

        Returns:
            (blended_probs, bert_probs_or_None, sklearn_probs, bert_std)
            where bert_std is all-zeros in sklearn-only fallback mode.
        """
        passes = MC_PASSES if n_passes is None else n_passes
        sk_probs = self._sklearn_probs(text)
        if not self._bert_available:
            # sklearn-only fallback: no epistemic std available.
            return sk_probs, None, sk_probs, np.zeros(len(CATEGORY_MAP))
        _, _, _, bert_mean, bert_std = self._bert_router.mc_predict(text, passes)
        blended = BERT_W * bert_mean + SKLEARN_W * sk_probs
        # Re-normalise (floating point can drift slightly).
        blended = blended / blended.sum()
        return blended, bert_mean, sk_probs, bert_std

    # ── Decision gate ────────────────────────────────────────────────────────

    @staticmethod
    def _boost_note(hist_boost: float, suffix: str = '') -> str:
        """HTML fragment describing the historical-memory boost ('' if no boost)."""
        if hist_boost <= 0:
            return ''
        note = (
            '<br>• <span style="color:var(--green)">'
            f'Historical Match Boost: +{hist_boost:.2%}</span>'
        )
        return f'{note} {suffix}' if suffix else note

    def _soft_gate(self, confidence, effective_conf, margin, entropy, top_two, hist_boost):
        """Shared medium-confidence tail: clarify when moderately sure, else escalate."""
        if effective_conf >= 0.60 and entropy < 1.05:
            reason = (
                '• Medium ambiguity detected'
                f'<br>• Clarification needed between {top_two[0]} and {top_two[1]}'
                f'<br>• Margin: {margin:.2f}'
            )
            return 'clarify', reason + self._boost_note(hist_boost, '(Insufficient for auto-route)')
        reason = (
            f'• High ambiguity / Low confidence ({confidence:.2%})'
            '<br>• Multiple overlapping intents detected'
            '<br>• Human triage needed'
        )
        return 'escalate', reason

    def _decide(self, category, ticket_text, confidence, effective_conf,
                margin, entropy, top_two, hist_boost, ranking):
        """
        Category-aware confidence gate.

        Sensitive labels use the strictest thresholds; technical_support gets a
        billing-overlap check plus slightly stricter thresholds; everything else
        uses the default gate.

        Returns:
            (action, reason_html) where action is 'route' | 'clarify' | 'escalate'.
        """
        if category in ('compliance_legal', 'account_management'):
            if effective_conf >= 0.90 and margin >= 0.35 and entropy < 0.60:
                reason = (
                    '• Safe to auto-route sensitive intent'
                    f'<br>• Confidence: {confidence:.2%}'
                    f'<br>• Margin: {margin:.2f}'
                )
                return 'route', reason + self._boost_note(hist_boost)
            reason = (
                f'• Escalated sensitive intent ({category})'
                '<br>• Strict confidence/margin threshold not met'
            )
            return 'escalate', reason + self._boost_note(hist_boost, '(Insufficient)')

        if category == 'technical_support':
            # Catch billing misroutes: billing keyword present AND billing in top-3.
            billing_keywords = ('invoice', 'billing', 'charge', 'refund',
                                'payment', 'subscription', 'plan')
            text_lc = ticket_text.lower()
            billing_in_top3 = 'billing' in [label for label, _ in ranking[:3]]
            if any(kw in text_lc for kw in billing_keywords) and billing_in_top3:
                return 'clarify', (
                    '• Billing overlap detected'
                    '<br>• Clarification needed between technical_support and billing'
                )
            # Stricter thresholds for technical_support.
            if effective_conf >= 0.88 and margin >= 0.30 and entropy < 0.65:
                reason = (
                    '• Strong dominant intent'
                    f'<br>• Confidence: {confidence:.2%}'
                    f'<br>• Margin: {margin:.2f}'
                    '<br>• Safe to auto-route'
                )
                return 'route', reason + self._boost_note(hist_boost)
            return self._soft_gate(confidence, effective_conf, margin, entropy, top_two, hist_boost)

        # Default gate for all remaining categories.
        if effective_conf >= 0.85 and margin >= 0.25 and entropy < 0.70:
            reason = (
                '• Strong dominant intent'
                f'<br>• Confidence: {confidence:.2%}'
                f'<br>• Margin: {margin:.2f}'
                '<br>• Safe to auto-route'
            )
            return 'route', reason + self._boost_note(hist_boost)
        return self._soft_gate(confidence, effective_conf, margin, entropy, top_two, hist_boost)

    # ── Public API ───────────────────────────────────────────────────────────

    def route(self, ticket_text: str, n_passes: Optional[int] = None) -> Dict:
        """
        Route a ticket through the ensemble confidence gate.

        Args:
            ticket_text: Raw ticket text.
            n_passes: MC Dropout passes (defaults to MC_PASSES).
                BUGFIX: previously this argument was accepted but ignored —
                _blend always used MC_PASSES while 'mc_passes' in the result
                reported the caller's value. It is now threaded through.

        Returns the same dict schema as ConfidenceGatedRouter.route() so it is
        a drop-in replacement in api.py.
        """
        passes = MC_PASSES if n_passes is None else n_passes
        blended, bert_p, sk_p, bert_std = self._blend(ticket_text, passes)

        n_classes = len(CATEGORY_MAP)
        confidence = float(blended.max())
        entropy = float(-np.sum(blended * np.log(blended + 1e-9)))
        category = CATEGORY_MAP[int(blended.argmax())]

        # Full ranking, highest probability first.
        ranking = sorted(
            [(CATEGORY_MAP[i], round(float(blended[i]), 4)) for i in range(n_classes)],
            key=lambda x: x[1], reverse=True
        )
        top_two = [ranking[0][0], ranking[1][0]]

        base = {
            'confidence': round(confidence, 4),
            'entropy': round(entropy, 4),
            'top_category': category,
            'all_probs': {CATEGORY_MAP[i]: round(float(blended[i]), 4) for i in range(n_classes)},
            'std_probs': {CATEGORY_MAP[i]: round(float(bert_std[i]), 4) for i in range(n_classes)},
            'category_ranking': ranking,
            'top_two_classes': top_two,
            'mc_passes': passes,
            # Extra ensemble diagnostics.
            'ensemble': {
                'bert_available': self._bert_available,
                'bert_top': CATEGORY_MAP[int(bert_p.argmax())] if bert_p is not None else None,
                'sklearn_top': CATEGORY_MAP[int(sk_p.argmax())],
                'bert_weight': BERT_W if self._bert_available else 0.0,
                'sklearn_weight': SKLEARN_W if self._bert_available else 1.0,
                'agreement': (
                    int(bert_p.argmax()) == int(sk_p.argmax())
                    if bert_p is not None else True
                ),
            },
        }

        margin = ranking[0][1] - ranking[1][1]
        hist_boost = 0.0
        memory = getattr(self, '_memory_layer', None)
        if memory and memory.is_ready:
            hist_boost = memory.compute_historical_boost(ticket_text, category)
        base['historical_boost'] = hist_boost
        base['margin'] = round(margin, 4)

        # The historical boost only affects the gate, not the reported confidence.
        effective_conf = confidence + hist_boost
        action, reason = self._decide(
            category, ticket_text, confidence, effective_conf,
            margin, entropy, top_two, hist_boost, ranking
        )
        return {**base, 'action': action,
                'queue': category if action == 'route' else None,
                'reason': reason}

    def batch_route(self, tickets: list, n_passes: Optional[int] = None) -> list:
        """Route a list of tickets sequentially; see route()."""
        return [self.route(t, n_passes) for t in tickets]

    # Accessors exposing model/tokenizer for the SHAP explainer in api.py.
    # NOTE(review): an earlier comment called these "properties" but they are
    # plain methods; kept as methods so existing .model()/.tokenizer() calls work.
    def model(self):
        """Return the underlying transformer model, or None in fallback mode."""
        if self._bert_available:
            return self._bert_router.model
        return None

    def tokenizer(self):
        """Return the transformer tokenizer, or None in fallback mode."""
        if self._bert_available:
            return self._bert_router.tokenizer
        return None
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
    router = EnsembleRouter()

    # Smoke-test tickets: roughly one clear example per category, plus
    # deliberately ambiguous cases at the end.
    tests = [
        "My invoice from last month is incorrect, please fix the billing.",
        "The API keeps returning 500 errors since last Tuesday's update.",
        "I want to cancel — this tool has been broken for weeks.",
        "How do I add another user to our account?",
        "We need GDPR data processing agreements for our EU customers.",
        "Not happy at all, considering switching to a competitor.",
        "Can you add a dark mode to the dashboard?",
        "Just signed up — how do I import my existing data?",
        # Tricky ambiguous cases
        "Invoice is wrong AND the app keeps crashing.",
        "Not happy with service",
    ]

    print(f"\n{'='*90}")
    print(f" SupportMind Ensemble Router — BERT={'ON' if router._bert_available else 'OFF (sklearn only)'}")
    print(f"{'='*90}\n")
    for ticket in tests:
        r = router.route(ticket)
        agree = 'AGREE' if r['ensemble']['agreement'] else 'DISAGREE'
        # H = high certainty (entropy below ENTROPY_MAX), L = low certainty.
        print(
            f"[{r['action'].upper():8s}] [{r['confidence']:.2%}] "
            f"{'H' if r['entropy'] < ENTROPY_MAX else 'L'}-certainty | "
            f"{r['top_category']:20s} | "
            f"Models: {agree} | {ticket[:60]}"
        )