Deepakvictor commited on
Commit
3da6512
·
1 Parent(s): 32681ab

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -0
handler.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ self.tokenizer = AutoTokenizer.from_pretrained(path).to(device)
8
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device))
9
+
10
+ def __call__(self, data: str) -> str:
11
+ inp = self.tokenizer(data, return_tensors="pt")
12
+ for q in inp:
13
+ inp[q] = inp[q].to(device)
14
+ with torch.inference_mode():
15
+ out= model.generate(**inp)
16
+ final_output = tokenizer.batch_decode(out,skip_special_tokens=True)
17
+ return {"translation": final_output[0]}