Jenny Pereira
changes in testing 2
037aa2f
raw
history blame contribute delete
No virus
990 Bytes
from typing import Dict, List, Any
import os
from flair.data import Sentence
from flair.models import SequenceTagger
class EndpointHandler:
    """Custom inference handler that runs a Flair ``SequenceTagger`` over text.

    The tagger is loaded once at construction time. Each call tags a single
    input text and returns the predicted spans as a list of entity dicts.
    """

    def __init__(self, path: str = ""):
        """Load the sequence tagger checkpoint found in *path*.

        Args:
            path: Directory containing the ``pytorch_model.bin`` checkpoint.
                Defaults to ``""``, in which case ``os.path.join`` yields the
                bare file name ``pytorch_model.bin`` (a relative path).
        """
        # The default must be a string value: using the type ``str`` as the
        # default would make os.path.join raise TypeError when no path is given.
        self.tagger = SequenceTagger.load(os.path.join(path, "pytorch_model.bin"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the text in *data* and return the predicted entity spans.

        Args:
            data: Request payload. The text is read from ``data["inputs"]``;
                if that key is absent, *data* itself is passed to ``Sentence``.
                The payload dict is not modified.

        Returns:
            One dict per predicted span with keys ``entity_group`` (the span's
            tag), ``word`` (the span text), ``start``/``end`` (character
            offsets of the span's first and last token) and ``score``.
        """
        # Read rather than pop so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        sentence: Sentence = Sentence(inputs)
        # Predictions are stored under a dedicated label name so that
        # get_spans below retrieves exactly the spans produced here.
        self.tagger.predict(sentence, label_name="predicted")

        entities = []
        for span in sentence.get_spans("predicted"):
            # A span without tokens has no first/last token to take offsets from.
            if not span.tokens:
                continue
            entities.append(
                {
                    "entity_group": span.tag,
                    "word": span.text,
                    "start": span.tokens[0].start_position,
                    "end": span.tokens[-1].end_position,
                    "score": span.score,
                }
            )
        return entities