Upload 6-head multi-task classifier (peril+severity+category+fire_sub+relevance+actionability)
5dfadff verified | import json, torch, torch.nn as nn | |
| from pathlib import Path | |
| from transformers import AutoModel, AutoTokenizer | |
# Closed label vocabularies for the classifier heads. List order is the
# contract: index i of each list maps to output unit i of the matching head,
# so the order must stay in sync with how the checkpoint was trained.
PERIL_LABELS = ["fire","flood","named_windstorm","construction_theft","transient_population","civil_unrest","earthquake"]  # multi-label head (scored independently)
SEVERITY_LABELS = ["low","medium","high","critical"]  # single-label head
CATEGORY_LABELS = ["incident_report","trend","regulatory","research","warning"]  # single-label head
FIRE_SUBCATEGORY_LABELS = ["arson","wildfire","unknown_cause"]  # single-label head
# NOTE: the relevance head is a single-logit binary score and has no label list.
ACTIONABILITY_LABELS = ["irrelevant","informational","notable","actionable"]  # single-label head
class MultiTaskClassifier(nn.Module):
    """Shared transformer encoder feeding six task-specific linear heads.

    Heads: peril (multi-label), severity, category, fire subcategory,
    relevance (single binary logit), and actionability.

    Attribute names (``encoder``, ``peril_head``, ...) are load-bearing:
    they form the state-dict keys of the saved checkpoint and must not change.
    """

    def __init__(self, model_name, np, ns, nc, nf, na=4):
        """Build the encoder and one linear head per task.

        np/ns/nc/nf/na are the class counts for the peril, severity,
        category, fire-subcategory and actionability heads respectively.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        width = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.peril_head = nn.Linear(width, np)
        self.severity_head = nn.Linear(width, ns)
        self.category_head = nn.Linear(width, nc)
        self.fire_sub_head = nn.Linear(width, nf)
        # Relevance is a single-logit binary score, not a label list.
        self.relevance_head = nn.Linear(width, 1)
        self.actionability_head = nn.Linear(width, na)

    def forward(self, input_ids, attention_mask=None):
        """Encode the batch and return a dict of raw logits, one entry per head."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool with the first ([CLS]) token, then apply dropout once and
        # fan the shared representation out to every head.
        cls_vec = self.dropout(encoded.last_hidden_state[:, 0, :])
        heads = {
            "peril_logits": self.peril_head,
            "severity_logits": self.severity_head,
            "category_logits": self.category_head,
            "fire_sub_logits": self.fire_sub_head,
            "relevance_logits": self.relevance_head,
            "actionability_logits": self.actionability_head,
        }
        return {key: head(cls_vec) for key, head in heads.items()}
class EndpointHandler:
    """Inference-endpoint handler wrapping MultiTaskClassifier.

    ``__call__`` accepts the standard HF endpoint payload
    ``{"inputs": <text>, "parameters": {...}}`` and returns a dict with
    peril scores plus severity / category / fire-subcategory predictions;
    relevance and actionability are opt-in via parameters.
    """

    def __init__(self, path=""):
        """Load tokenizer, model skeleton and (if present) local weights.

        path: directory containing the tokenizer files, model config and
        an optional ``pytorch_model.bin`` checkpoint.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = MultiTaskClassifier(
            path,
            len(PERIL_LABELS),
            len(SEVERITY_LABELS),
            len(CATEGORY_LABELS),
            len(FIRE_SUBCATEGORY_LABELS),
            len(ACTIONABILITY_LABELS),
        )
        weights = Path(path) / "pytorch_model.bin"
        if weights.exists():
            # weights_only=True blocks arbitrary pickle code execution.
            self.model.load_state_dict(
                torch.load(str(weights), map_location=self.device, weights_only=True)
            )
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _top_class(logits, labels):
        """Softmax a (1, C) logit row; return the argmax label and its
        confidence rounded to 4 decimals."""
        probs = torch.softmax(logits, -1)[0].cpu().tolist()
        idx = int(logits.argmax(-1)[0].cpu())
        return {"label": labels[idx], "confidence": round(probs[idx], 4)}

    def __call__(self, data):
        """Classify one text payload and return the result dict.

        data: {"inputs": str | list[str], "parameters": {"include_relevance":
        bool, "include_actionability": bool}}. A list input uses only its
        first element (batching is not supported).
        """
        text = data.get("inputs", "")
        if isinstance(text, list):
            # Fix: an empty list previously raised IndexError on text[0].
            text = text[0] if text else ""
        if text is None:
            # Fix: explicit None payloads previously crashed on slicing.
            text = ""
        params = data.get("parameters", {})
        include_relevance = params.get("include_relevance", False)
        include_actionability = params.get("include_actionability", False)
        # Cheap character-level pre-truncation; the tokenizer still enforces
        # the 512-token model limit.
        inputs = self.tokenizer(
            text[:16000], truncation=True, max_length=512, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            out = self.model(**inputs)
        # Peril is multi-label: independent sigmoid score per label.
        peril_probs = torch.sigmoid(out["peril_logits"])[0].cpu().tolist()
        result = {
            "peril_scores": {lab: round(p, 4) for lab, p in zip(PERIL_LABELS, peril_probs)},
            "severity": self._top_class(out["severity_logits"], SEVERITY_LABELS),
            "category": self._top_class(out["category_logits"], CATEGORY_LABELS),
            "fire_subcategory": self._top_class(out["fire_sub_logits"], FIRE_SUBCATEGORY_LABELS),
        }
        if include_relevance:
            rel = float(torch.sigmoid(out["relevance_logits"].squeeze(-1))[0].cpu())
            result["relevance"] = {
                "score": round(rel, 4),
                "label": "relevant" if rel > 0.5 else "irrelevant",
            }
        if include_actionability:
            result["actionability"] = self._top_class(
                out["actionability_logits"], ACTIONABILITY_LABELS
            )
        return result