Spaces:
Sleeping
Sleeping
| """ | |
| MamaGuard — Explainability | |
| Gradient-based attribution for per-feature and per-visit importance. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from src.model import MamaGuardMamba3 | |
# Names of the per-visit input features. explain_prediction pairs these
# positionally with the feature axis of the attribution matrix and of the
# inverse-scaled input, so the order here must match the column order of
# the data fed to the model.
FEATURE_NAMES = [
    "Age",
    "SystolicBP",
    "DiastolicBP",
    "BloodSugar",
    "BodyTemp",
    "HeartRate",
]

# Human-readable label for each class id; indexed by the predicted class.
RISK_LABELS = ["Low risk", "Medium risk", "High risk"]

# Emoji shown alongside each label; indexed the same way as RISK_LABELS.
RISK_EMOJI = ["🟢", "🟡", "🔴"]
def explain_prediction(
    model: MamaGuardMamba3,
    x_sequence: np.ndarray,
    scaler,
    device: str = "cpu",
    *,
    top_k: int = 2,
    trend_threshold: float = 0.5,
) -> dict:
    """Run the model on one patient and explain the predicted risk class.

    Attribution is the absolute gradient of the predicted-class logit with
    respect to the (scaled) input, averaged over visits for per-feature
    importance and over features for per-visit importance. Both importance
    vectors are normalised to sum to ~1.

    Args:
        model: Model to explain. It is switched to eval mode and is NOT moved
            to ``device`` here; the caller must already have placed it there.
        x_sequence: 2-D array of shape ``(n_visits, len(FEATURE_NAMES))`` in
            the scaled space the model consumes.
        scaler: Object exposing ``inverse_transform(array)``; used to map
            ``x_sequence`` back to original units for the textual reasons.
            Presumably the fitted scaler used during training (confirm
            against the caller).
        device: Torch device on which the input tensor is created.
        top_k: Number of highest-importance features turned into reasons.
        trend_threshold: Minimum absolute first-to-last change (in original
            units) for a feature to be reported as rising/falling.

    Returns:
        Dict with keys ``risk_level``, ``risk_emoji``, ``probabilities``,
        ``confidence``, ``top_reasons``, ``feature_importance`` and
        ``visit_importance``.

    Raises:
        ValueError: If ``x_sequence`` is not a non-empty 2-D array with one
            column per entry of ``FEATURE_NAMES``, or if ``top_k`` is
            negative.
    """
    x_array = np.asarray(x_sequence)
    if x_array.ndim != 2 or x_array.shape[0] == 0:
        raise ValueError(
            "x_sequence must be a non-empty 2-D array of shape "
            f"(n_visits, n_features); got shape {x_array.shape}"
        )
    if x_array.shape[1] != len(FEATURE_NAMES):
        # Without this check zip() below would silently drop or mislabel
        # columns and the normalisation would include unreported features.
        raise ValueError(
            f"x_sequence must have {len(FEATURE_NAMES)} feature columns; "
            f"got {x_array.shape[1]}"
        )
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative; got {top_k}")

    model.eval()

    # Add a batch dimension of 1 and track gradients on the input itself.
    x_tensor = torch.tensor(x_array, dtype=torch.float32).unsqueeze(0).to(device)
    x_tensor.requires_grad_(True)

    # Forward pass. Softmax is only used for reporting; attribution is taken
    # on the raw logit of the predicted class.
    logits = model(x_tensor)
    probs = F.softmax(logits, dim=-1)
    pred_class = probs.argmax(dim=-1).item()
    confidence = probs[0, pred_class].item()

    # Gradient attribution. torch.autograd.grad returns the input gradient
    # directly, so the model parameters' .grad buffers are left untouched
    # (backward() would accumulate into every parameter as a side effect).
    (input_grad,) = torch.autograd.grad(logits[0, pred_class], x_tensor)
    attribution = np.abs(input_grad[0].cpu().numpy())

    # Per-feature importance (average over visits); epsilon guards against
    # division by zero when every gradient is zero.
    feature_importance = attribution.mean(axis=0)
    feature_importance = feature_importance / (feature_importance.sum() + 1e-9)

    # Per-visit importance (average over features).
    visit_importance = attribution.mean(axis=1)
    visit_importance = visit_importance / (visit_importance.sum() + 1e-9)

    # Rank feature indices by importance, highest first. sorted() is stable,
    # so ties keep FEATURE_NAMES order.
    ranked_indices = sorted(
        range(len(FEATURE_NAMES)),
        key=lambda idx: feature_importance[idx],
        reverse=True,
    )

    # Build human-readable reasons from values in original (unscaled) units.
    x_orig = scaler.inverse_transform(x_array)
    top_reasons = []
    for feat_idx in ranked_indices[:top_k]:
        feat = FEATURE_NAMES[feat_idx]
        vals = x_orig[:, feat_idx]
        trend = vals[-1] - vals[0]
        if abs(trend) > trend_threshold:
            direction = "rising" if trend > 0 else "falling"
            top_reasons.append(
                f"{feat} is {direction} (from {vals[0]:.1f} to {vals[-1]:.1f})"
            )
        else:
            # NOTE(review): "elevated" is asserted without comparing against
            # any reference range; the value is only known to be stable.
            top_reasons.append(
                f"{feat} is consistently elevated (avg {vals.mean():.1f})"
            )

    # Assemble result with plain Python floats so it is JSON-serialisable.
    probs_np = probs[0].detach().cpu().numpy()
    return {
        "risk_level": RISK_LABELS[pred_class],
        "risk_emoji": RISK_EMOJI[pred_class],
        "probabilities": {
            label: round(float(p), 4)
            for label, p in zip(RISK_LABELS, probs_np)
        },
        "confidence": round(confidence, 4),
        "top_reasons": top_reasons,
        "feature_importance": {
            feat: round(float(imp), 4)
            for feat, imp in zip(FEATURE_NAMES, feature_importance)
        },
        "visit_importance": [round(float(v), 4) for v in visit_importance],
    }