leblanciii committed on
Commit
5dfadff
·
verified ·
1 Parent(s): 0127207

Upload 6-head multi-task classifier (peril+severity+category+fire_sub+relevance+actionability)

Browse files
Files changed (3) hide show
  1. handler.py +22 -7
  2. label_config.json +6 -0
  3. pytorch_model.bin +2 -2
handler.py CHANGED
@@ -6,9 +6,10 @@ PERIL_LABELS = ["fire","flood","named_windstorm","construction_theft","transient
6
  SEVERITY_LABELS = ["low","medium","high","critical"]
7
  CATEGORY_LABELS = ["incident_report","trend","regulatory","research","warning"]
8
  FIRE_SUBCATEGORY_LABELS = ["arson","wildfire","unknown_cause"]
 
9
 
10
  class MultiTaskClassifier(nn.Module):
11
- def __init__(self, model_name, np, ns, nc, nf):
12
  super().__init__()
13
  self.encoder = AutoModel.from_pretrained(model_name)
14
  h = self.encoder.config.hidden_size
@@ -17,17 +18,20 @@ class MultiTaskClassifier(nn.Module):
17
  self.severity_head = nn.Linear(h, ns)
18
  self.category_head = nn.Linear(h, nc)
19
  self.fire_sub_head = nn.Linear(h, nf)
 
 
20
  def forward(self, input_ids, attention_mask=None):
21
  out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
22
  pooled = self.dropout(out.last_hidden_state[:, 0, :])
23
  return {"peril_logits": self.peril_head(pooled), "severity_logits": self.severity_head(pooled),
24
- "category_logits": self.category_head(pooled), "fire_sub_logits": self.fire_sub_head(pooled)}
 
25
 
26
  class EndpointHandler:
27
  def __init__(self, path=""):
28
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  self.tokenizer = AutoTokenizer.from_pretrained(path)
30
- self.model = MultiTaskClassifier(path, len(PERIL_LABELS), len(SEVERITY_LABELS), len(CATEGORY_LABELS), len(FIRE_SUBCATEGORY_LABELS))
31
  w = Path(path) / "pytorch_model.bin"
32
  if w.exists():
33
  self.model.load_state_dict(torch.load(str(w), map_location=self.device, weights_only=True))
@@ -36,6 +40,9 @@ class EndpointHandler:
36
  def __call__(self, data):
37
  text = data.get("inputs", "")
38
  if isinstance(text, list): text = text[0]
 
 
 
39
  inputs = self.tokenizer(text[:16000], truncation=True, max_length=512, return_tensors="pt").to(self.device)
40
  with torch.no_grad():
41
  out = self.model(**inputs)
@@ -46,7 +53,15 @@ class EndpointHandler:
46
  ci = int(out["category_logits"].argmax(-1)[0].cpu())
47
  fp = torch.softmax(out["fire_sub_logits"], -1)[0].cpu().tolist()
48
  fi = int(out["fire_sub_logits"].argmax(-1)[0].cpu())
49
- return {"peril_scores": {l: round(s,4) for l,s in zip(PERIL_LABELS, pp)},
50
- "severity": {"label": SEVERITY_LABELS[si], "confidence": round(sp[si],4)},
51
- "category": {"label": CATEGORY_LABELS[ci], "confidence": round(cp[ci],4)},
52
- "fire_subcategory": {"label": FIRE_SUBCATEGORY_LABELS[fi], "confidence": round(fp[fi],4)}}
 
 
 
 
 
 
 
 
 
6
  SEVERITY_LABELS = ["low","medium","high","critical"]
7
  CATEGORY_LABELS = ["incident_report","trend","regulatory","research","warning"]
8
  FIRE_SUBCATEGORY_LABELS = ["arson","wildfire","unknown_cause"]
9
+ ACTIONABILITY_LABELS = ["irrelevant","informational","notable","actionable"]
10
 
11
  class MultiTaskClassifier(nn.Module):
12
+ def __init__(self, model_name, np, ns, nc, nf, na=4):
13
  super().__init__()
14
  self.encoder = AutoModel.from_pretrained(model_name)
15
  h = self.encoder.config.hidden_size
 
18
  self.severity_head = nn.Linear(h, ns)
19
  self.category_head = nn.Linear(h, nc)
20
  self.fire_sub_head = nn.Linear(h, nf)
21
+ self.relevance_head = nn.Linear(h, 1)
22
+ self.actionability_head = nn.Linear(h, na)
23
  def forward(self, input_ids, attention_mask=None):
24
  out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
25
  pooled = self.dropout(out.last_hidden_state[:, 0, :])
26
  return {"peril_logits": self.peril_head(pooled), "severity_logits": self.severity_head(pooled),
27
+ "category_logits": self.category_head(pooled), "fire_sub_logits": self.fire_sub_head(pooled),
28
+ "relevance_logits": self.relevance_head(pooled), "actionability_logits": self.actionability_head(pooled)}
29
 
30
  class EndpointHandler:
31
  def __init__(self, path=""):
32
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
  self.tokenizer = AutoTokenizer.from_pretrained(path)
34
+ self.model = MultiTaskClassifier(path, len(PERIL_LABELS), len(SEVERITY_LABELS), len(CATEGORY_LABELS), len(FIRE_SUBCATEGORY_LABELS), len(ACTIONABILITY_LABELS))
35
  w = Path(path) / "pytorch_model.bin"
36
  if w.exists():
37
  self.model.load_state_dict(torch.load(str(w), map_location=self.device, weights_only=True))
 
40
  def __call__(self, data):
41
  text = data.get("inputs", "")
42
  if isinstance(text, list): text = text[0]
43
+ params = data.get("parameters", {})
44
+ include_relevance = params.get("include_relevance", False)
45
+ include_actionability = params.get("include_actionability", False)
46
  inputs = self.tokenizer(text[:16000], truncation=True, max_length=512, return_tensors="pt").to(self.device)
47
  with torch.no_grad():
48
  out = self.model(**inputs)
 
53
  ci = int(out["category_logits"].argmax(-1)[0].cpu())
54
  fp = torch.softmax(out["fire_sub_logits"], -1)[0].cpu().tolist()
55
  fi = int(out["fire_sub_logits"].argmax(-1)[0].cpu())
56
+ result = {"peril_scores": {l: round(s,4) for l,s in zip(PERIL_LABELS, pp)},
57
+ "severity": {"label": SEVERITY_LABELS[si], "confidence": round(sp[si],4)},
58
+ "category": {"label": CATEGORY_LABELS[ci], "confidence": round(cp[ci],4)},
59
+ "fire_subcategory": {"label": FIRE_SUBCATEGORY_LABELS[fi], "confidence": round(fp[fi],4)}}
60
+ if include_relevance:
61
+ rel_score = float(torch.sigmoid(out["relevance_logits"].squeeze(-1))[0].cpu())
62
+ result["relevance"] = {"score": round(rel_score, 4), "label": "relevant" if rel_score > 0.5 else "irrelevant"}
63
+ if include_actionability:
64
+ ap = torch.softmax(out["actionability_logits"], -1)[0].cpu().tolist()
65
+ ai = int(out["actionability_logits"].argmax(-1)[0].cpu())
66
+ result["actionability"] = {"label": ACTIONABILITY_LABELS[ai], "confidence": round(ap[ai], 4)}
67
+ return result
label_config.json CHANGED
@@ -25,5 +25,11 @@
25
  "0": "arson",
26
  "1": "wildfire",
27
  "2": "unknown_cause"
 
 
 
 
 
 
28
  }
29
  }
 
25
  "0": "arson",
26
  "1": "wildfire",
27
  "2": "unknown_cause"
28
+ },
29
+ "actionability_labels": {
30
+ "0": "irrelevant",
31
+ "1": "informational",
32
+ "2": "notable",
33
+ "3": "actionable"
34
  }
35
  }
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2bd721a7ff60be333ebd50cc28b86a2eab55037ab3fdc135e04150a778ee57f7
3
- size 498727455
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33da6dbfb4dc60916ae56804d7fe4affe89b7fd456fc89aa652cdc7e34c6cdf8
3
+ size 498744175