olgen committed on
Commit
2200677
1 Parent(s): d876181

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +46 -0
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import os
3
+ from flair.data import Sentence
4
+ from flair.models import SequenceTagger
5
+
6
class EndpointHandler:
    """Request handler that runs a Flair sequence tagger over input text.

    The tagger is loaded once in ``__init__`` and reused by every call.
    """

    def __init__(
        self,
        path: str,
    ):
        """Load the Flair tagger from the model directory.

        Args:
            path: Directory that contains the ``pytorch_model.bin`` checkpoint.
        """
        self.tagger = SequenceTagger.load(os.path.join(path, "pytorch_model.bin"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the request text and return the predicted spans.

        Args:
            data: Request payload. The text to tag is read from
                ``data["inputs"]``; the payload itself is not modified.

        Returns:
            A list with one dict per predicted span, shaped like
            ``[{"entity_group": "XXX", "word": "some word", "start": 3,
            "end": 6, "score": 0.82}]``:

            - "entity_group": the tag assigned to the span.
            - "word": the text of the span.
            - "start": ``start_position`` of the span's first token.
            - "end": ``end_position`` of the span's last token.
            - "score": the score the tagger reports for the span.

        Raises:
            ValueError: If ``data`` has no ``"inputs"`` entry.
        """
        # Fail loudly instead of handing the whole payload dict to Sentence,
        # and read (rather than pop) so the caller's dict is left untouched.
        if "inputs" not in data:
            raise ValueError(
                'Request payload must contain an "inputs" entry with the text to tag.'
            )
        sentence = Sentence(data["inputs"])

        # Predictions are stored under the "predicted" label name, which is
        # the same name the spans are read back from below.
        self.tagger.predict(sentence, label_name="predicted")

        return [
            {
                "entity_group": span.tag,
                "word": span.text,
                "start": span.tokens[0].start_position,
                "end": span.tokens[-1].end_position,
                "score": span.score,
            }
            for span in sentence.get_spans("predicted")
            if span.tokens  # skip spans that carry no tokens
        ]