philschmid's picture
philschmid HF staff
Update handler.py
70e8d2e
raw history blame
No virus
1.81 kB
from typing import Any, Dict, List
import os
from flair.data import Sentence
from flair.models import SequenceTagger
class EndpointHandler:
    """Callable handler that runs a Flair sequence tagger over request text.

    The tagger is loaded once in ``__init__`` and reused for every call.
    """

    def __init__(
        self,
        path: str,
    ):
        """Load the tagger.

        Args:
            path: Directory containing the ``pytorch_model.bin`` file to load.
        """
        self.tagger = SequenceTagger.load(os.path.join(path, "pytorch_model.bin"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the text found under ``data["inputs"]`` and return its spans.

        Args:
            data: Request payload. Must contain an ``"inputs"`` entry holding
                the text to tag. The payload is not modified.

        Returns:
            A list with one dict per predicted span, shaped like
            ``[{"entity_group": "XXX", "word": "some word", "start": 3,
            "end": 6, "score": 0.82}]``, where:

            - ``"entity_group"``: the tag assigned to the span.
            - ``"word"``: the text of the span.
            - ``"start"``: start position of the span's first token.
            - ``"end"``: end position of the span's last token.
            - ``"score"``: the score reported for the span.

        Raises:
            ValueError: If ``"inputs"`` is missing from ``data`` or is ``None``.
        """
        # Read without popping: no other key is used, so there is no reason to
        # mutate the caller's payload.
        inputs = data.get("inputs")
        if inputs is None:
            # Previously a missing key fell back to the payload dict itself,
            # which was then handed to Sentence(...) instead of real text.
            raise ValueError('Request payload must contain a non-null "inputs" field.')

        sentence = Sentence(inputs)

        # Predictions are stored under the "predicted" label and read back
        # from the same label below.
        self.tagger.predict(sentence, label_name="predicted")

        return [
            {
                "entity_group": span.tag,
                "word": span.text,
                "start": span.tokens[0].start_position,
                "end": span.tokens[-1].end_position,
                "score": span.score,
            }
            for span in sentence.get_spans("predicted")
            if span.tokens  # skip spans without tokens; offsets need at least one
        ]