|
from typing import Dict, List, Any |
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
import torch |
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference handler that translates text with an M2M100 model.

    The model and tokenizer are loaded once from ``path`` when the handler is
    constructed; each call translates a single input string into the
    requested target language.
    """

    def __init__(self, path=""):
        """Load the M2M100 model (in bfloat16) and its tokenizer from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        # Place the model on the GPU when one is available. Otherwise it would
        # always stay on the CPU, even on GPU-backed endpoints.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = M2M100ForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.bfloat16
        ).to(device)
        self.tokenizer = M2M100Tokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        """Translate ``data["text"]`` into the language ``data["langId"]``.

        Args:
            data: Request payload. It must contain:
                - ``"text"``: the text to translate.
                - ``"langId"``: the target-language code, as accepted by
                  ``tokenizer.get_lang_id``.

        Returns:
            ``{"translated": <translated text>}``. Only the first decoded
            sequence is returned.

        Raises:
            ValueError: If ``"text"`` or ``"langId"`` is missing from ``data``.

        Note:
            This handler does not set the tokenizer's source language.
            Encoding uses whatever ``self.tokenizer.src_lang`` is currently
            configured to.
        """
        text = data.get("text")
        lang_id = data.get("langId")
        # Fail fast with a clear message. The previous fallback passed the
        # whole request dict into the tokenizer, which failed confusingly.
        if text is None or lang_id is None:
            raise ValueError(
                "Request must include both 'text' and 'langId' fields."
            )

        # Move the input tensors to wherever the model lives.
        encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Forcing the first generated token to the target-language id is
        # how M2M100 selects the output language.
        generated_tokens = self.model.generate(
            **encoded,
            forced_bos_token_id=self.tokenizer.get_lang_id(lang_id),
        )
        translated = self.tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
        return {"translated": translated}