Upload 6-head multi-task classifier (peril+severity+category+fire_sub+relevance+actionability)
Browse files
- handler.py +22 -7
- label_config.json +6 -0
- pytorch_model.bin +2 -2
handler.py
CHANGED
|
@@ -6,9 +6,10 @@ PERIL_LABELS = ["fire","flood","named_windstorm","construction_theft","transient
|
|
| 6 |
SEVERITY_LABELS = ["low","medium","high","critical"]
|
| 7 |
CATEGORY_LABELS = ["incident_report","trend","regulatory","research","warning"]
|
| 8 |
FIRE_SUBCATEGORY_LABELS = ["arson","wildfire","unknown_cause"]
|
|
|
|
| 9 |
|
| 10 |
class MultiTaskClassifier(nn.Module):
|
| 11 |
-
def __init__(self, model_name, np, ns, nc, nf):
|
| 12 |
super().__init__()
|
| 13 |
self.encoder = AutoModel.from_pretrained(model_name)
|
| 14 |
h = self.encoder.config.hidden_size
|
|
@@ -17,17 +18,20 @@ class MultiTaskClassifier(nn.Module):
|
|
| 17 |
self.severity_head = nn.Linear(h, ns)
|
| 18 |
self.category_head = nn.Linear(h, nc)
|
| 19 |
self.fire_sub_head = nn.Linear(h, nf)
|
|
|
|
|
|
|
| 20 |
def forward(self, input_ids, attention_mask=None):
|
| 21 |
out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 22 |
pooled = self.dropout(out.last_hidden_state[:, 0, :])
|
| 23 |
return {"peril_logits": self.peril_head(pooled), "severity_logits": self.severity_head(pooled),
|
| 24 |
-
"category_logits": self.category_head(pooled), "fire_sub_logits": self.fire_sub_head(pooled)
|
|
|
|
| 25 |
|
| 26 |
class EndpointHandler:
|
| 27 |
def __init__(self, path=""):
|
| 28 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 30 |
-
self.model = MultiTaskClassifier(path, len(PERIL_LABELS), len(SEVERITY_LABELS), len(CATEGORY_LABELS), len(FIRE_SUBCATEGORY_LABELS))
|
| 31 |
w = Path(path) / "pytorch_model.bin"
|
| 32 |
if w.exists():
|
| 33 |
self.model.load_state_dict(torch.load(str(w), map_location=self.device, weights_only=True))
|
|
@@ -36,6 +40,9 @@ class EndpointHandler:
|
|
| 36 |
def __call__(self, data):
|
| 37 |
text = data.get("inputs", "")
|
| 38 |
if isinstance(text, list): text = text[0]
|
|
|
|
|
|
|
|
|
|
| 39 |
inputs = self.tokenizer(text[:16000], truncation=True, max_length=512, return_tensors="pt").to(self.device)
|
| 40 |
with torch.no_grad():
|
| 41 |
out = self.model(**inputs)
|
|
@@ -46,7 +53,15 @@ class EndpointHandler:
|
|
| 46 |
ci = int(out["category_logits"].argmax(-1)[0].cpu())
|
| 47 |
fp = torch.softmax(out["fire_sub_logits"], -1)[0].cpu().tolist()
|
| 48 |
fi = int(out["fire_sub_logits"].argmax(-1)[0].cpu())
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
SEVERITY_LABELS = ["low","medium","high","critical"]
|
| 7 |
CATEGORY_LABELS = ["incident_report","trend","regulatory","research","warning"]
|
| 8 |
FIRE_SUBCATEGORY_LABELS = ["arson","wildfire","unknown_cause"]
|
| 9 |
+
ACTIONABILITY_LABELS = ["irrelevant","informational","notable","actionable"]
|
| 10 |
|
| 11 |
class MultiTaskClassifier(nn.Module):
|
| 12 |
+
def __init__(self, model_name, np, ns, nc, nf, na=4):
|
| 13 |
super().__init__()
|
| 14 |
self.encoder = AutoModel.from_pretrained(model_name)
|
| 15 |
h = self.encoder.config.hidden_size
|
|
|
|
| 18 |
self.severity_head = nn.Linear(h, ns)
|
| 19 |
self.category_head = nn.Linear(h, nc)
|
| 20 |
self.fire_sub_head = nn.Linear(h, nf)
|
| 21 |
+
self.relevance_head = nn.Linear(h, 1)
|
| 22 |
+
self.actionability_head = nn.Linear(h, na)
|
| 23 |
def forward(self, input_ids, attention_mask=None):
|
| 24 |
out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 25 |
pooled = self.dropout(out.last_hidden_state[:, 0, :])
|
| 26 |
return {"peril_logits": self.peril_head(pooled), "severity_logits": self.severity_head(pooled),
|
| 27 |
+
"category_logits": self.category_head(pooled), "fire_sub_logits": self.fire_sub_head(pooled),
|
| 28 |
+
"relevance_logits": self.relevance_head(pooled), "actionability_logits": self.actionability_head(pooled)}
|
| 29 |
|
| 30 |
class EndpointHandler:
|
| 31 |
def __init__(self, path=""):
|
| 32 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 33 |
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 34 |
+
self.model = MultiTaskClassifier(path, len(PERIL_LABELS), len(SEVERITY_LABELS), len(CATEGORY_LABELS), len(FIRE_SUBCATEGORY_LABELS), len(ACTIONABILITY_LABELS))
|
| 35 |
w = Path(path) / "pytorch_model.bin"
|
| 36 |
if w.exists():
|
| 37 |
self.model.load_state_dict(torch.load(str(w), map_location=self.device, weights_only=True))
|
|
|
|
| 40 |
def __call__(self, data):
|
| 41 |
text = data.get("inputs", "")
|
| 42 |
if isinstance(text, list): text = text[0]
|
| 43 |
+
params = data.get("parameters", {})
|
| 44 |
+
include_relevance = params.get("include_relevance", False)
|
| 45 |
+
include_actionability = params.get("include_actionability", False)
|
| 46 |
inputs = self.tokenizer(text[:16000], truncation=True, max_length=512, return_tensors="pt").to(self.device)
|
| 47 |
with torch.no_grad():
|
| 48 |
out = self.model(**inputs)
|
|
|
|
| 53 |
ci = int(out["category_logits"].argmax(-1)[0].cpu())
|
| 54 |
fp = torch.softmax(out["fire_sub_logits"], -1)[0].cpu().tolist()
|
| 55 |
fi = int(out["fire_sub_logits"].argmax(-1)[0].cpu())
|
| 56 |
+
result = {"peril_scores": {l: round(s,4) for l,s in zip(PERIL_LABELS, pp)},
|
| 57 |
+
"severity": {"label": SEVERITY_LABELS[si], "confidence": round(sp[si],4)},
|
| 58 |
+
"category": {"label": CATEGORY_LABELS[ci], "confidence": round(cp[ci],4)},
|
| 59 |
+
"fire_subcategory": {"label": FIRE_SUBCATEGORY_LABELS[fi], "confidence": round(fp[fi],4)}}
|
| 60 |
+
if include_relevance:
|
| 61 |
+
rel_score = float(torch.sigmoid(out["relevance_logits"].squeeze(-1))[0].cpu())
|
| 62 |
+
result["relevance"] = {"score": round(rel_score, 4), "label": "relevant" if rel_score > 0.5 else "irrelevant"}
|
| 63 |
+
if include_actionability:
|
| 64 |
+
ap = torch.softmax(out["actionability_logits"], -1)[0].cpu().tolist()
|
| 65 |
+
ai = int(out["actionability_logits"].argmax(-1)[0].cpu())
|
| 66 |
+
result["actionability"] = {"label": ACTIONABILITY_LABELS[ai], "confidence": round(ap[ai], 4)}
|
| 67 |
+
return result
|
label_config.json
CHANGED
|
@@ -25,5 +25,11 @@
|
|
| 25 |
"0": "arson",
|
| 26 |
"1": "wildfire",
|
| 27 |
"2": "unknown_cause"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
}
|
| 29 |
}
|
|
|
|
| 25 |
"0": "arson",
|
| 26 |
"1": "wildfire",
|
| 27 |
"2": "unknown_cause"
|
| 28 |
+
},
|
| 29 |
+
"actionability_labels": {
|
| 30 |
+
"0": "irrelevant",
|
| 31 |
+
"1": "informational",
|
| 32 |
+
"2": "notable",
|
| 33 |
+
"3": "actionable"
|
| 34 |
}
|
| 35 |
}
|
pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33da6dbfb4dc60916ae56804d7fe4affe89b7fd456fc89aa652cdc7e34c6cdf8
|
| 3 |
+
size 498744175
|