from typing import Dict, List, Any
import os

from flair.data import Sentence
from flair.models import SequenceTagger


class EndpointHandler:
    """Inference handler that runs a flair ``SequenceTagger`` over input text.

    The tagger is loaded once at construction time. Each call tags one input
    and returns the predicted spans as a list of dictionaries.
    """

    def __init__(self, path: str = ""):
        """Load the tagger from ``<path>/pytorch_model.bin``.

        Args:
            path: Directory containing ``pytorch_model.bin``. The default empty
                string resolves the file relative to the current working
                directory.
        """
        # ``path`` must default to an actual string: ``os.path.join`` raises
        # TypeError if handed a non-str object such as the ``str`` type itself.
        self.tagger = SequenceTagger.load(os.path.join(path, "pytorch_model.bin"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the request payload and return the predicted spans.

        Args:
            data: Request payload. The value under ``"inputs"`` is used when
                present; otherwise the whole payload is passed to ``Sentence``.
                NOTE(review): presumably ``"inputs"`` holds a plain text
                string -- confirm against the calling service.

        Returns:
            One dict per predicted span that has at least one token, with keys
            ``entity_group`` (span tag), ``word`` (span text), ``start`` (first
            token's ``start_position``), ``end`` (last token's
            ``end_position``) and ``score`` (span score).
        """
        # Read rather than pop so the caller's payload dict is left untouched.
        inputs = data.get("inputs", data)
        sentence: Sentence = Sentence(inputs)

        # Predictions are stored on the sentence under this label name and
        # read back below with get_spans using the same name.
        self.tagger.predict(sentence, label_name="predicted")

        entities: List[Dict[str, Any]] = []
        for span in sentence.get_spans("predicted"):
            # Skip token-less spans: they have no first/last token to take
            # character positions from.
            if not span.tokens:
                continue
            entities.append(
                {
                    "entity_group": span.tag,
                    "word": span.text,
                    "start": span.tokens[0].start_position,
                    "end": span.tokens[-1].end_position,
                    "score": span.score,
                }
            )
        return entities