ikeno-ada committed on
Commit
7b841fa
1 Parent(s): af4df46

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -0
handler.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import intel_extension_for_pytorch as ipex
2
+ from typing import Dict, List, Any
3
+ from transformers import AutoModelForSeq2SeqLM, NllbTokenizerFast
4
+ import torch
5
+
6
+
7
+
8
class EndpointHandler():
    """Translation request handler built on a seq2seq model and an NLLB tokenizer.

    The model and tokenizer are loaded once when the handler is constructed
    and are reused for every call.
    """

    def __init__(self, path=""):
        """Load the model (as bfloat16) and its tokenizer from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path, torch_dtype=torch.bfloat16)
        self.tokenizer = NllbTokenizerFast.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        """Translate ``data["text"]`` into the language given by ``data["langId"]``.

        Args:
            data: Request payload. Must contain ``"text"`` (the text to
                translate) and ``"langId"`` (the target language code, which
                has to be a token known to the tokenizer).

        Returns:
            A dict with the single key ``"translated"`` holding the decoded
            translation of the first (only) sequence.

        Raises:
            ValueError: If ``"text"`` or ``"langId"`` is missing, or if the
                language code is not in the tokenizer's vocabulary.
        """
        text = data.get("text")
        lang_id = data.get("langId")
        if text is None or lang_id is None:
            raise ValueError('request payload must contain both "text" and "langId"')

        # Resolve the target language token up front so that an unknown code
        # fails loudly instead of silently forcing the <unk> token.
        target_token_id = self.tokenizer.convert_tokens_to_ids(lang_id)
        if target_token_id == self.tokenizer.unk_token_id:
            raise ValueError(f"unknown target language code: {lang_id!r}")

        # tokenize the input
        inputs = self.tokenizer(text, return_tensors="pt")
        # run the model, forcing the first generated token to the target language
        translated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=target_token_id,
            max_length=512,
        )
        translated = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return {"translated": translated}