bhai_insights_v2 / handler.py
Reizendretail's picture
Update handler.py
c94c5f0
raw
history blame
No virus
1.37 kB
from typing import Dict, List, Any
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler():
    """Custom inference handler that runs a chat-style causal language model.

    The model and tokenizer are loaded once from ``path`` at construction.
    Each call formats an optional system message plus a user message with the
    tokenizer's chat template, samples a completion, and returns the decoded
    text.
    """

    # Default keyword arguments for ``model.generate``. A request may override
    # any of them (or add others, e.g. ``max_new_tokens``) through its
    # "parameters" dict. NOTE(review): no length limit is set here, so unless
    # the request supplies one, output length is decided by the model's own
    # generation config -- confirm that is acceptable for this model.
    DEFAULT_GENERATE_KWARGS = {
        "do_sample": True,
        "top_p": 0.9,
        "top_k": 50,
        "temperature": 0.6,
        "num_beams": 1,
        "repetition_penalty": 1.2,
    }

    def __init__(self, path=""):
        """Load the model and tokenizer stored at ``path``."""
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Use only the system prompt supplied per request, never a default
        # system prompt built into the tokenizer's chat template.
        self.tokenizer.use_default_system_prompt = False

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a chat completion for one request.

        Args:
            data: Request payload. The keys below are popped, so the dict is
                mutated.
                inputs (str): the user message. Required; a missing key
                    raises ``KeyError``.
                system_prompt (str, optional): system message placed before
                    the user message. Left out of the conversation when
                    missing or ``None``.
                parameters (dict, optional): keyword arguments for
                    ``model.generate`` that override ``DEFAULT_GENERATE_KWARGS``.

        Returns:
            A one-element list ``[{"generated_text": str}]`` containing only
            the newly generated text (the prompt tokens are stripped). Plain
            str/dict/list values keep the result serializable.
        """
        # get inputs
        system_prompt = data.pop("system_prompt", None)
        message = data.pop("inputs")
        parameters = data.pop("parameters", None) or {}

        conversation = []
        if system_prompt is not None:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.append({"role": "user", "content": message})
        logging.info("%s", conversation)

        # add_generation_prompt appends the assistant-turn header (for
        # templates that define one) so the model answers instead of
        # continuing the user's turn.
        input_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
        input_ids = input_ids.to(self.model.device)

        # input_ids goes last so a stray "input_ids" in parameters cannot
        # replace the prompt.
        generate_kwargs = {
            **self.DEFAULT_GENERATE_KWARGS,
            **parameters,
            "input_ids": input_ids,
        }
        output_ids = self.model.generate(**generate_kwargs)

        # For decoder-only models the output starts with the prompt tokens;
        # slice them off so only the completion is decoded.
        prompt_length = input_ids.shape[-1]
        generated_text = self.tokenizer.decode(
            output_ids[0, prompt_length:], skip_special_tokens=True
        )
        return [{"generated_text": generated_text}]