tomaarsen HF staff commited on
Commit
f989676
1 Parent(s): e1eba79

Add handler.py to support inference endpoints

Browse files
Files changed (1) hide show
  1. handler.py +33 -0
handler.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from span_marker import SpanMarkerModel
4
+
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, model_id: str) -> None:
8
+ self.model = SpanMarkerModel.from_pretrained(model_id)
9
+ # Try to place it on CUDA, do nothing if it fails
10
+ self.model.try_cuda()
11
+
12
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
+ """
14
+ Args:
15
+ data (Dict[str, Any]):
16
+ a dictionary with the "inputs" key corresponding to a string containing some text
17
+ Return:
18
+ A List[Dict[str, Any]]:. The object returned should be like [{"entity_group": "XXX", "word": "some word", "start": 3, "end": 6, "score": 0.82}] containing :
19
+ - "entity_group": A string representing what the entity is.
20
+ - "word": A rubstring of the original string that was detected as an entity.
21
+ - "start": the offset within `input` leading to `answer`. context[start:stop] == word
22
+ - "end": the ending offset within `input` leading to `answer`. context[start:stop] === word
23
+ - "score": A score between 0 and 1 describing how confident the model is for this entity.
24
+ """
25
+ return [
26
+ {
27
+ "entity_group": entity["label"],
28
+ "word": entity["span"],
29
+ "start": entity["char_start_index"],
30
+ "end": entity["char_end_index"],
31
+ "score": entity["score"],
32
+ }
33
+ for entity in self.model.predict(data["inputs"])