from typing import Any, Dict, List
import logging

from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)

# Sampling settings applied to every request unless the request overrides
# them through an optional "parameters" dict.
_DEFAULT_GENERATE_KWARGS: Dict[str, Any] = {
    "do_sample": True,
    "top_p": 0.9,
    "top_k": 50,
    "temperature": 0.6,
    "num_beams": 1,
    "repetition_penalty": 1.2,
}


class EndpointHandler:
    """Answer single-turn chat requests with a causal language model."""

    def __init__(self, path: str = ""):
        """Load the model and tokenizer.

        Args:
            path: Directory or model id passed to ``from_pretrained`` for both
                the model and the tokenizer.
        """
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Turn off the tokenizer's built-in default system prompt so only the
        # caller-supplied one (if any) ends up in the chat template.
        self.tokenizer.use_default_system_prompt = False

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate an assistant reply for one user message.

        Args:
            data: Request payload. It is not modified. Keys read:
                ``inputs`` (str, required): the user message.
                ``system_prompt`` (str, optional): system message; omitted
                    from the conversation when absent.
                ``parameters`` (dict, optional): keyword arguments merged over
                    the default sampling settings passed to ``generate``.

        Returns:
            A one-element list ``[{"generated_text": <reply>}]`` that can be
            serialized and returned to the client. The reply contains only
            the newly generated text, not the prompt.

        Raises:
            KeyError: if ``inputs`` is missing from ``data``.
        """
        message = data["inputs"]
        system_prompt = data.get("system_prompt")
        # ``or {}`` also covers an explicit ``"parameters": null``.
        parameters = data.get("parameters") or {}

        conversation = []
        if system_prompt is not None:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.append({"role": "user", "content": message})
        logger.info("%s", conversation)

        # add_generation_prompt=True appends the assistant-turn header so the
        # model produces a reply rather than continuing the user's turn.
        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_kwargs = {**_DEFAULT_GENERATE_KWARGS, **parameters}
        output_ids = self.model.generate(
            input_ids=input_ids,
            # The single sequence is unpadded, so every position is attended.
            attention_mask=input_ids.new_ones(input_ids.shape),
            **generate_kwargs,
        )

        # generate() returns prompt + completion; keep only the new tokens.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": generated_text}]