Create handler.py
Browse files- handler.py +48 -0
handler.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
from translate import EncoderCT2fromHfHub
|
4 |
+
import os
|
5 |
+
|
6 |
+
class EndpointHandler():
|
7 |
+
def __init__(self, path="", local_test=False):
|
8 |
+
path = "tmp/ct2fast-e5-small-v2-hfie"
|
9 |
+
snapshot_id = "1"
|
10 |
+
if local_test:
|
11 |
+
repo_dir = os.getcwd()
|
12 |
+
else:
|
13 |
+
repo_dir = "/repository/"
|
14 |
+
cache_dir = os.path.join(os.path.expanduser("~/.cache/huggingface/hub/models--" + path.replace("/", "--")))
|
15 |
+
snapshot_dir = cache_dir + "/snapshots/" + snapshot_id
|
16 |
+
os.makedirs(cache_dir + "/refs", exist_ok=True)
|
17 |
+
os.makedirs(snapshot_dir, exist_ok=True)
|
18 |
+
with open(cache_dir + "/refs/main", 'w') as filee:
|
19 |
+
filee.write(snapshot_id)
|
20 |
+
for filee in "config.json", "model.bin", "tokenizer_config.json", "tokenizer.json", "vocabulary.txt":
|
21 |
+
# Make symbolic links
|
22 |
+
link = os.path.join(snapshot_dir, filee)
|
23 |
+
if not(os.path.exists(link)): os.symlink(os.path.join(repo_dir,filee), link)
|
24 |
+
self.model = EncoderCT2fromHfHub(
|
25 |
+
model_name_or_path=path,
|
26 |
+
device="cuda",
|
27 |
+
compute_type="int8_float16"
|
28 |
+
)
|
29 |
+
|
30 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
31 |
+
"""
|
32 |
+
data args:
|
33 |
+
inputs (:obj: `str`)
|
34 |
+
Return:
|
35 |
+
A :obj:`list` | `dict`: will be serialized and returned
|
36 |
+
"""
|
37 |
+
inputs = data.pop("inputs",data)
|
38 |
+
outputs = self.model.generate(text=[inputs])
|
39 |
+
return outputs["pooler_output"].tolist()
|
40 |
+
|
41 |
+
# Test code
|
42 |
+
if __name__ == '__main__':
|
43 |
+
from handler import EndpointHandler
|
44 |
+
my_handler = EndpointHandler(path=".", local_test=True)
|
45 |
+
inputs = ['The quick brown fox jumps over the lazy dog']
|
46 |
+
for input in inputs:
|
47 |
+
response = my_handler({"inputs": input})
|
48 |
+
print(response)
|