File size: 1,388 Bytes
ec8a5c6 de2af79 ec8a5c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from typing import Dict, List, Any
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler():
    """Custom inference handler that runs chat-style text generation.

    Loads a causal language model and its tokenizer from ``path``; each
    request is rendered through the tokenizer's chat template and a sampled
    completion is returned as plain text.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer stored at ``path``.

        Args:
            path: Directory (or model identifier) passed to
                ``from_pretrained`` for both the model and the tokenizer.
        """
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Presumably disables any built-in default system prompt of the
        # tokenizer so only the request's ``system_prompt`` is used.
        self.tokenizer.use_default_system_prompt = False

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a chat completion for a single request.

        Args:
            data: Request payload. Keys are consumed with ``pop``:
                inputs (str): the user message. Required.
                system_prompt (str, optional): system message placed before
                    the user message; left out of the conversation when
                    missing or empty.
                parameters (dict, optional): keyword arguments forwarded to
                    ``model.generate``; they override the default sampling
                    settings used below.

        Returns:
            ``[{"generated_text": str}]`` -- only the newly generated text
            (the prompt tokens are stripped), so the result is serializable.

        Raises:
            KeyError: if ``inputs`` is missing from ``data``.
        """
        system_prompt = data.pop("system_prompt", None)
        message = data.pop("inputs")
        parameters = data.pop("parameters", None) or {}

        conversation = []
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.append({"role": "user", "content": message})
        logging.info("conversation: %s", conversation)

        # add_generation_prompt appends the assistant-turn header so the model
        # answers the user rather than continuing the user's message.
        input_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
        input_ids = input_ids.to(self.model.device)

        generate_kwargs = {
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.6,
            "num_beams": 1,
            "repetition_penalty": 1.2,
            **parameters,
            # Always generate from the rendered prompt, never from a
            # caller-supplied "input_ids" in ``parameters``.
            "input_ids": input_ids,
        }
        output_ids = self.model.generate(**generate_kwargs)

        # generate() returns prompt + completion for each row; keep only the
        # tokens produced after the prompt for the single request row.
        prompt_length = input_ids.shape[-1]
        generated_text = self.tokenizer.decode(
            output_ids[0, prompt_length:], skip_special_tokens=True
        )
        return [{"generated_text": generated_text}]