flore2003 commited on
Commit
10a1b40
1 Parent(s): 9e5dc17

feat: Adds handler.py for inference endpoints

Browse files
Files changed (3) hide show
  1. README.md +5 -0
  2. handler.py +46 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -15,6 +15,11 @@ widget:
15
  - text: "George Washington ging nach Washington"
16
  ---
17
 
 
 
 
 
 
18
  ## 4-Language NER in Flair (English, German, Dutch and Spanish)
19
 
20
  This is the standard 4-class NER model for 4 CoNLL-03 languages that ships with [Flair](https://github.com/flairNLP/flair/). Also kind of works for related languages like French.
 
15
  - text: "George Washington ging nach Washington"
16
  ---
17
 
18
+ # This is a fork of flair/ner-multi
19
+
20
+ As `flair/ner-multi` is missing a `handler.py`, this form implements a custom `handler.py` to be used with inference endpoints. The original model can be found here: [https://huggingface.co/flair/ner-multi](https://huggingface.co/flair/ner-multi)
21
+
22
+
23
  ## 4-Language NER in Flair (English, German, Dutch and Spanish)
24
 
25
  This is the standard 4-class NER model for 4 CoNLL-03 languages that ships with [Flair](https://github.com/flairNLP/flair/). Also kind of works for related languages like French.
handler.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import os
3
+ from flair.data import Sentence
4
+ from flair.models import SequenceTagger
5
+
6
+ class EndpointHandler():
7
+ def __init__(
8
+ self,
9
+ path: str,
10
+ ):
11
+ self.tagger = SequenceTagger.load(os.path.join(path,"pytorch_model.bin"))
12
+
13
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
+ """
15
+ Args:
16
+ inputs (:obj:`str`):
17
+ a string containing some text
18
+ Return:
19
+ A :obj:`list`:. The object returned should be like [{"entity_group": "XXX", "word": "some word", "start": 3, "end": 6, "score": 0.82}] containing :
20
+ - "entity_group": A string representing what the entity is.
21
+ - "word": A substring of the original string that was detected as an entity.
22
+ - "start": the offset within `input` leading to `answer`. context[start:stop] == word
23
+ - "end": the ending offset within `input` leading to `answer`. context[start:stop] === word
24
+ - "score": A score between 0 and 1 describing how confident the model is for this entity.
25
+ """
26
+ inputs = data.pop("inputs", data)
27
+ sentence: Sentence = Sentence(inputs)
28
+
29
+ # Also show scores for recognized NEs
30
+ self.tagger.predict(sentence, label_name="predicted")
31
+
32
+ entities = []
33
+ for span in sentence.get_spans("predicted"):
34
+ if len(span.tokens) == 0:
35
+ continue
36
+ current_entity = {
37
+ "entity_group": span.tag,
38
+ "word": span.text,
39
+ "start": span.tokens[0].start_position,
40
+ "end": span.tokens[-1].end_position,
41
+ "score": span.score,
42
+ }
43
+
44
+ entities.append(current_entity)
45
+
46
+ return entities
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ flair==0.12.2