from typing import Any, Dict, List

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


class EndpointHandler:
    """Custom inference handler that translates text with an M2M100 model."""

    def __init__(self, path: str = ""):
        """Load the M2M100 model and tokenizer.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                weights and the tokenizer files.
        """
        # Weights are loaded in bfloat16 to reduce memory versus float32.
        self.model = M2M100ForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.bfloat16
        )
        self.tokenizer = M2M100Tokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Translate the request text into the requested target language.

        Args:
            data: Request payload. ``"text"`` holds the source text (``"inputs"``
                is accepted as a fallback key); ``"langId"`` is the target
                language code handed to ``M2M100Tokenizer.get_lang_id``.

        Returns:
            ``{"translated": <decoded first generated sequence>}``.

        Raises:
            ValueError: if the source text or the target language is missing.
        """
        text = data.get("text", data.get("inputs"))
        lang_id = data.get("langId")
        # Fail fast with a clear message instead of feeding the raw request
        # dict into the tokenizer / get_lang_id.
        if text is None:
            raise ValueError("request must include 'text' (or 'inputs')")
        if lang_id is None:
            raise ValueError("request must include 'langId' (target language code)")

        # Tokenize and move the tensors to wherever the model's weights live.
        encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        generated_tokens = self.model.generate(
            **encoded,
            # Force the first decoder token to the target-language token so the
            # model generates in the requested language.
            forced_bos_token_id=self.tokenizer.get_lang_id(lang_id),
        )
        # Only the first decoded sequence is returned (single-text contract).
        result = self.tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
        return {"translated": result}