from typing import Dict, Any

import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


class EndpointHandler():
    def __init__(self, path=""):
        # load the model in bfloat16 and the matching tokenizer
        self.model = M2M100ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.bfloat16)
        self.tokenizer = M2M100Tokenizer.from_pretrained(path)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        Args:
            data (:obj:`dict`):
                The request payload. Expected keys:
                  - "text": the text to translate
                  - "langId": the target language code (e.g. "en", "fr")
        Returns:
            A dict with the translated text under the "translated" key.
        """
        text = data.get("text", data)
        lang_id = data.get("langId", data)
        # tokenize the input and move it onto the model's device
        encoded = self.tokenizer(text, return_tensors="pt")
        encoded = encoded.to(self.model.device)
        # generate the translation, forcing the first generated token to the target language
        generated_tokens = self.model.generate(
            **encoded, forced_bos_token_id=self.tokenizer.get_lang_id(lang_id)
        )
        result = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        return {"translated": result}