Nottybro committed on
Commit
90f2a5f
·
verified ·
1 Parent(s): 30c982b

deploy: classifier_inference.py

Browse files
Files changed (1) hide show
  1. classifier_inference.py +34 -0
classifier_inference.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
3
+
4
# Identifier handed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_ID = "Nottybro/acra-classifier"

# Label name for each class index; indexed by the predicted class position.
LABEL_NAMES = [
    "L0_direct",
    "L1_single_hop",
    "L2_multi_hop",
    "L3_complex",
]

# Module-level cache for the tokenizer and model. Both start empty and are
# filled in by _load() the first time they are needed.
_tok = None
_mdl = None
8
+
9
def _load():
    """Fill the module-level tokenizer and model caches on first use.

    Once ``_mdl`` has been set, later calls return immediately without
    reloading anything. On the first call a progress line is printed, then
    the tokenizer and the sequence-classification model are loaded from
    ``MODEL_ID`` and the model is put into evaluation mode.
    """
    global _tok, _mdl

    # Already loaded: nothing to do.
    if _mdl is not None:
        return

    print(f"Loading classifier from {MODEL_ID}...")
    _tok = DistilBertTokenizerFast.from_pretrained(MODEL_ID)
    _mdl = DistilBertForSequenceClassification.from_pretrained(MODEL_ID)
    # Inference only: switch the module to evaluation mode.
    _mdl.eval()
16
+
17
def warm_up():
    """Load the classifier and push one sample query through it.

    Prints a confirmation line after the sample classification returns.
    """
    _load()
    # The prediction itself is discarded; the call only exercises the
    # tokenizer and model once.
    _ = classify_query("what is the capital of france")
    print("Classifier warm ✓")
21
+
22
def classify_query(query: str) -> dict:
    """Classify *query* and return the predicted level with per-class scores.

    The query is tokenized to a fixed length of 128 tokens (padded and
    truncated), run through the model without gradient tracking, and the
    logits are converted to probabilities with a softmax.

    Returns:
        A dict with four keys:
          - ``"level"``: index of the highest-probability class (int).
          - ``"label"``: the ``LABEL_NAMES`` entry at that index.
          - ``"confidence"``: that class's probability, rounded to 4 places.
          - ``"scores"``: every class probability keyed ``"L<index>"``,
            each rounded to 4 places.
    """
    _load()

    encoded = _tok(
        query,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = _mdl(**encoded).logits
        # squeeze() drops the size-1 batch dimension of the single query.
        probs = torch.softmax(logits, dim=-1).squeeze()

    best = int(probs.argmax())

    per_class = {}
    for idx, prob in enumerate(probs):
        per_class[f"L{idx}"] = round(prob.item(), 4)

    return {
        "level": best,
        "label": LABEL_NAMES[best],
        "confidence": round(probs[best].item(), 4),
        "scores": per_class,
    }