philschmid HF staff commited on
Commit
654c5de
1 Parent(s): 902d14f

Create new file

Browse files
Files changed (1) hide show
  1. handler.py +47 -0
handler.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from app.pipelines import Pipeline
4
+ from flair.data import Sentence
5
+ from flair.models import SequenceTagger
6
+
7
+ class EndpointHandler():
8
+ def __init__(
9
+ self,
10
+ model_id: str,
11
+ ):
12
+ self.tagger = SequenceTagger.load(model_id)
13
+
14
+ def __call__(self, inputs: str) -> List[Dict[str, Any]]:
15
+ """
16
+ Args:
17
+ inputs (:obj:`str`):
18
+ a string containing some text
19
+ Return:
20
+ A :obj:`list`:. The object returned should be like [{"entity_group": "XXX", "word": "some word", "start": 3, "end": 6, "score": 0.82}] containing :
21
+ - "entity_group": A string representing what the entity is.
22
+ - "word": A substring of the original string that was detected as an entity.
23
+ - "start": the offset within `input` leading to `answer`. context[start:stop] == word
24
+ - "end": the ending offset within `input` leading to `answer`. context[start:stop] === word
25
+ - "score": A score between 0 and 1 describing how confident the model is for this entity.
26
+ """
27
+ inputs = data.pop("inputs", data)
28
+ sentence: Sentence = Sentence(inputs)
29
+
30
+ # Also show scores for recognized NEs
31
+ self.tagger.predict(sentence, label_name="predicted")
32
+
33
+ entities = []
34
+ for span in sentence.get_spans("predicted"):
35
+ if len(span.tokens) == 0:
36
+ continue
37
+ current_entity = {
38
+ "entity_group": span.tag,
39
+ "word": span.text,
40
+ "start": span.tokens[0].start_position,
41
+ "end": span.tokens[-1].end_position,
42
+ "score": span.score,
43
+ }
44
+
45
+ entities.append(current_entity)
46
+
47
+ return entities