File size: 976 Bytes
0b048f0
146f822
 
 
 
 
5e381e7
146f822
 
0b048f0
 
146f822
0b048f0
6a28def
146f822
 
 
 
 
 
 
 
 
 
0b048f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor, PegasusForConditionalGeneration, AlbertTokenizer
import torch

# Prefer the default CUDA device when a GPU is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EndpointHandler():
    """Custom handler for a Hugging Face Inference Endpoint.

    Loads a ``PegasusForConditionalGeneration`` model and an ``AlbertTokenizer``
    from ``path``. On each call it generates output text for the request's
    ``"inputs"`` payload.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer.

        Args:
            path: Location passed to ``from_pretrained`` for both the model and
                the tokenizer.
        """
        # The model is moved to the module-level ``device`` (CUDA when available).
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(device)
        self.tokenizer = AlbertTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate output text for the request payload.

        Args:
            data: Request payload. The text to process is read from
                ``data["inputs"]`` (a string, or presumably a list of strings
                for batching). As in the endpoint template, the whole payload
                is used when the key is absent. The payload is not modified.

        Returns:
            ``{"predictions": [...]}``, with one decoded string per generated
            sequence (special tokens stripped).
        """
        # Read without mutating the caller's dict (the original used ``pop``).
        sentences = data.get("inputs", data)

        # padding=True lets a list of sentences of different lengths be stacked
        # into one tensor. For a single string it adds nothing. The attention
        # mask is kept so generate() ignores any padding positions.
        encoded = self.tokenizer(
            sentences,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=True,
            padding=True,
        )

        # The tensors must be on the same device as the model. Otherwise
        # generate() raises a device-mismatch error when the model is on GPU.
        encoded = encoded.to(self.model.device)

        output_ids = self.model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

        return {"predictions": self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)}