from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class EndpointHandler:
    """Custom inference handler that translates text with a seq2seq model.

    The tokenizer is configured with ``src_lang="en"`` and ``tgt_lang="ta"``.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model.

        Args:
            path: Model directory or hub identifier passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.src_lang = "en"
        self.tokenizer.tgt_lang = "ta"
        # NOTE(review): for M2M100-style checkpoints, setting tgt_lang on the
        # tokenizer does not by itself force the output language; generate()
        # typically needs forced_bos_token_id -- confirm for this model. Callers
        # can supply it through the request's "parameters".
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

    def __call__(self, data: Union[Dict[str, Any], str]) -> Dict[str, Any]:
        """Translate the request's text.

        Args:
            data: Either the text itself, or a payload dict with an ``"inputs"``
                key (a string or a list of strings) and an optional
                ``"parameters"`` dict of keyword arguments forwarded to
                ``model.generate``. If ``"inputs"`` is absent, the dict itself is
                treated as the input (which is rejected below).

        Returns:
            ``{"translation": str}`` for a single string input, or
            ``{"translation": list[str]}`` for a list input.

        Raises:
            ValueError: If the inputs are neither a string nor a list.
        """
        if isinstance(data, str):
            inputs: Any = data
            parameters = None
        else:
            # .get rather than .pop so the caller's payload is left untouched.
            inputs = data.get("inputs", data)
            parameters = data.get("parameters")

        if not isinstance(inputs, (str, list)):
            raise ValueError("'inputs' must be a string or a list of strings")

        batched = isinstance(inputs, list)
        if batched and not inputs:
            return {"translation": []}

        generate_kwargs: Dict[str, Any] = dict(parameters) if parameters else {}

        # Pad only when batching: sequences of unequal length cannot be stacked
        # into one tensor otherwise, and a single string needs no padding.
        encoded = self.tokenizer(inputs, return_tensors="pt", padding=batched)
        with torch.inference_mode():
            output_ids = self.model.generate(**encoded, **generate_kwargs)
        translations: List[str] = self.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True
        )
        return {"translation": translations if batched else translations[0]}