from typing import Dict, List, Any

from transformers import (
    LayoutLMForTokenClassification,
    LayoutLMv2Processor,
    PegasusForConditionalGeneration,
    AlbertTokenizer,
)
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Callable handler that rewrites input text with a Pegasus model.

    The model and tokenizer are loaded once in ``__init__``; each call
    tokenizes the request text, runs ``generate`` and decodes the result.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer stored under *path*.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        # The model is placed on ``device``; inputs must be moved to the
        # same device before generation (see ``__call__``).
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(device)
        self.tokenizer = AlbertTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate predictions for the text carried in *data*.

        Args:
            data: Request payload. The text is taken from the ``"inputs"``
                key, which is removed from the dict; when the key is absent
                the payload itself is handed to the tokenizer.

        Returns:
            A dict with a single ``"predictions"`` key holding the decoded
            generated sequences, with special tokens stripped.
        """
        # Extract the text to process (pops "inputs" from the payload).
        sentence = data.pop("inputs", data)

        # Tokenize; only the input ids are requested from the tokenizer.
        encoded = self.tokenizer(
            sentence,
            return_tensors='pt',
            return_token_type_ids=False,
            return_attention_mask=False,
        )

        # The tokenizer returns CPU tensors while the model lives on
        # ``device``; move the ids over so generation does not fail with a
        # device mismatch when CUDA is in use.
        input_ids = encoded['input_ids'].to(device)

        # Run prediction and decode the generated token ids back to text.
        generated = self.model.generate(input_ids)
        return {
            "predictions": self.tokenizer.batch_decode(
                generated, skip_special_tokens=True
            )
        }