|
from typing import Dict, List, Any |
|
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor, PegasusForConditionalGeneration, AlbertTokenizer |
|
import torch |
|
|
|
# Select the accelerator once at import time: use CUDA when torch can see a
# GPU, otherwise fall back to CPU. Shared by model placement and input tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoint handler.

    Wraps a Pegasus seq2seq model (with an ALBERT/sentencepiece tokenizer)
    that rewrites an input sentence — presumably grammar/spelling correction,
    judging by the ``wrongSentence`` naming; confirm against the model card.
    """

    def __init__(self, path=""):
        """Load model weights and tokenizer from ``path``.

        Args:
            path: Directory (or hub id) containing the pretrained model and
                tokenizer files; supplied by the Inference Endpoint runtime.
        """
        # Model goes to GPU when available; the tokenizer is CPU-only.
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(device)
        self.tokenizer = AlbertTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> Dict[str, List[str]]:
        """Run one inference request.

        Args:
            data: Endpoint payload; the text to correct is expected under the
                standard ``"inputs"`` key (falls back to the payload itself
                when the key is absent).

        Returns:
            ``{"predictions": [...]}`` — the decoded generations, with
            special tokens stripped.
        """
        wrong_sentence = data.pop("inputs", data)

        # Pegasus only consumes input_ids here, so skip token-type ids and
        # the attention mask. (Renamed from `input`, which shadowed the builtin.)
        encoded = self.tokenizer(
            wrong_sentence,
            return_tensors="pt",
            return_token_type_ids=False,
            return_attention_mask=False,
        )

        # BUG FIX: the ids must be moved to the model's device — the model was
        # placed on `device` in __init__, and generate() would otherwise fail
        # with a device-mismatch error whenever CUDA is available.
        # no_grad: inference only, no autograd bookkeeping needed.
        with torch.no_grad():
            output_ids = self.model.generate(encoded["input_ids"].to(device))

        return {"predictions": self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)}
|
|