File size: 990 Bytes
abd01f0
037aa2f
abd01f0
 
 
 
037aa2f
abd01f0
037aa2f
abd01f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Dict, List, Any
import os
from flair.data import Sentence
from flair.models import SequenceTagger

class EndpointHandler:
    """Inference Endpoints handler that runs a flair ``SequenceTagger``.

    The tagger is loaded once in ``__init__``; each call wraps the request
    text in a flair ``Sentence``, tags it, and returns the predicted spans
    as a list of plain dicts.
    """

    # Label name the predictions are written under by ``predict`` and read
    # back from by ``get_spans``; both uses must agree, so keep it in one place.
    _LABEL_NAME = "predicted"

    def __init__(self, path: str = ""):
        """Load the tagger from ``pytorch_model.bin`` inside ``path``.

        Args:
            path: Directory containing ``pytorch_model.bin``. Defaults to an
                empty string, i.e. the file is resolved relative to the
                current working directory.
        """
        # NOTE: this must be a real default value, not ``path=str`` -- the
        # latter makes the default the ``str`` type object itself, which
        # ``os.path.join`` rejects with a TypeError.
        model_file = os.path.join(path, "pytorch_model.bin")
        self.tagger = SequenceTagger.load(model_file)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the request text and return one dict per predicted span.

        Args:
            data: Request payload. The text is taken from ``data["inputs"]``;
                if that key is absent, ``data`` itself is handed to
                ``Sentence``. The payload is not modified.

        Returns:
            A list of dicts with keys ``entity_group`` (the span's tag),
            ``word`` (the span's text), ``start`` (``start_position`` of the
            span's first token), ``end`` (``end_position`` of its last token)
            and ``score`` (the span's score). Spans with no tokens are skipped.
        """
        # ``get`` instead of ``pop`` so the caller's payload is left intact;
        # the fallback value is the same as before.
        inputs = data.get("inputs", data)
        sentence: Sentence = Sentence(inputs)

        self.tagger.predict(sentence, label_name=self._LABEL_NAME)
        return [
            {
                "entity_group": span.tag,
                "word": span.text,
                "start": span.tokens[0].start_position,
                "end": span.tokens[-1].end_position,
                "score": span.score,
            }
            for span in sentence.get_spans(self._LABEL_NAME)
            # tokens[0] / tokens[-1] below require a non-empty span
            if span.tokens
        ]