from typing import Dict, List, Any

from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast


class EndpointHandler:
    """Inference handler that translates text with a 4-bit-loaded NLLB model."""

    def __init__(self, path: str = ""):
        """Load the translation model and its tokenizer.

        Args:
            path: Directory (or hub id) containing the model weights and
                the NLLB tokenizer files.
        """
        # Load the weights quantized to 4 bits to reduce memory use.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path, load_in_4bit=True)
        self.tokenizer = NllbTokenizerFast.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Translate the request text into the requested target language.

        Args:
            data: Request payload. The text to translate is read from
                ``"text"`` (or, if absent, from ``"inputs"``). The target
                language is read from ``"langId"`` and must be an NLLB
                language code token known to the tokenizer (e.g. ``"fra_Latn"``).

        Returns:
            ``{"translated": <first decoded translation>}``.

        Raises:
            ValueError: if the text or ``"langId"`` field is missing.
            KeyError: if ``"langId"`` is not a language token known to the
                tokenizer.
        """
        text = data.get("text")
        if text is None:
            # Accept the conventional Inference Endpoints payload key too.
            text = data.get("inputs")
        if text is None:
            raise ValueError("request must contain a 'text' (or 'inputs') field")

        lang_id = data.get("langId")
        if not lang_id:
            raise ValueError("request must contain a 'langId' field (target language code)")

        # convert_tokens_to_ids works across transformers versions, unlike the
        # deprecated/removed `lang_code_to_id` mapping. Unknown tokens map to
        # the unk id, which is rejected so a typo does not silently produce
        # output in an arbitrary language.
        lang_token_id = self.tokenizer.convert_tokens_to_ids(lang_id)
        if lang_token_id is None or lang_token_id == self.tokenizer.unk_token_id:
            raise KeyError(f"unknown target language code: {lang_id!r}")

        # Tokenize, then move the tensors to the model's device: the quantized
        # model may live on an accelerator while the tokenizer produces CPU
        # tensors. padding=True only matters when `text` is a list.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.model.device)

        # Force the first generated token to be the target-language tag.
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=lang_token_id,
            max_length=512,
        )
        res = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

        return {"translated": res}