import os
from typing import Dict, List, Any

# EncoderCT2fromHfHub is provided by the hf-hub-ctranslate2 package
from hf_hub_ctranslate2 import EncoderCT2fromHfHub

class EndpointHandler:

    def __init__(self, path="", local_test=False):
        # The incoming `path` is ignored: the CTranslate2 files are exposed to
        # the Hugging Face Hub cache under a pseudo repo id so that
        # EncoderCT2fromHfHub can resolve them without downloading anything.
        path = "tmp/ct2fast-e5-small-v2-hfie"
        snapshot_id = "1"

        if local_test:
            repo_dir = os.getcwd()
        else:
            # Inference Endpoints mount the model repository at /repository.
            repo_dir = "/repository/"

        cache_dir = os.path.expanduser(
            "~/.cache/huggingface/hub/models--" + path.replace("/", "--")
        )
        snapshot_dir = os.path.join(cache_dir, "snapshots", snapshot_id)
        os.makedirs(os.path.join(cache_dir, "refs"), exist_ok=True)
        os.makedirs(snapshot_dir, exist_ok=True)

        # Register the snapshot as the "main" revision of the cache entry.
        with open(os.path.join(cache_dir, "refs", "main"), "w") as ref_file:
            ref_file.write(snapshot_id)

        # Symlink the model files from the repository into the snapshot dir.
        for fname in (
            "config.json", "model.bin", "tokenizer_config.json",
            "tokenizer.json", "vocabulary.txt",
        ):
            link = os.path.join(snapshot_dir, fname)
            if not os.path.exists(link):
                os.symlink(os.path.join(repo_dir, fname), link)

        # Load the int8/float16-quantized encoder with CTranslate2 on GPU.
        self.model = EncoderCT2fromHfHub(
            model_name_or_path=path,
            device="cuda",
            compute_type="int8_float16",
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        data args:
            inputs (:obj:`str`): the text to embed
        Return:
            A :obj:`list` of embedding vectors; it will be serialized and returned
        """
        inputs = data.pop("inputs", data)
        outputs = self.model.generate(text=[inputs])
        # Return the pooled sentence embedding as a JSON-serializable nested list.
        return outputs["pooler_output"].tolist()

if __name__ == "__main__":
    # Quick local smoke test against the files in the current directory.
    my_handler = EndpointHandler(path=".", local_test=True)
    inputs = ["The quick brown fox jumps over the lazy dog"]
    for text in inputs:
        response = my_handler({"inputs": text})
        print(response)
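
# A minimal sketch of how the deployed endpoint could be queried over HTTP,
# assuming a hypothetical endpoint URL and token (both placeholders):
#
#   import requests
#
#   response = requests.post(
#       "https://<your-endpoint>.endpoints.huggingface.cloud",
#       headers={"Authorization": "Bearer <hf_token>"},
#       json={"inputs": "The quick brown fox jumps over the lazy dog"},
#   )
#   print(response.json())  # nested list holding the pooled embedding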