popaqy's picture
Update handler.py
0b048f0
from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor, PegasusForConditionalGeneration, AlbertTokenizer
import torch
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler():
    """Inference handler that runs a Pegasus generation model on request text.

    The instance is callable: it tokenizes the request payload, runs
    ``generate`` on the model, and returns the decoded sequences.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path``.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        # The model is placed on the module-level ``device`` once, at start-up.
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(device)
        self.tokenizer = AlbertTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> Dict[str, List[str]]:
        """Generate output text for one request.

        Args:
            data: Request payload. The ``"inputs"`` entry is removed from
                the dict and used as the text to process; when that key is
                absent, the payload itself is handed to the tokenizer.

        Returns:
            A dict with a single ``"predictions"`` key holding the list of
            decoded sequences (special tokens stripped).
        """
        # Extract the text to process (falls back to the whole payload).
        wrong_sentence = data.pop("inputs", data)

        # Only the token ids are needed; masks and token-type ids are skipped.
        encoded = self.tokenizer(
            wrong_sentence,
            return_tensors='pt',
            return_token_type_ids=False,
            return_attention_mask=False,
        )

        # The tokenizer returns CPU tensors; move the ids to wherever the
        # model lives so generation does not fail on a GPU host.
        input_ids = encoded['input_ids'].to(self.model.device)

        output = self.model.generate(input_ids)
        return {"predictions": self.tokenizer.batch_decode(output, skip_special_tokens=True)}