File size: 1,388 Bytes
ec8a5c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de2af79
ec8a5c6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from typing import Dict, List, Any
import logging
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler():
    """Custom inference handler that serves a chat-style causal language model.

    The model and tokenizer are loaded once from ``path`` at construction time.
    Each call builds a chat conversation from the request payload, renders it
    with the tokenizer's chat template, samples a continuation with
    ``model.generate`` and returns the decoded text.
    """

    def __init__(self, path=""):
        """Load the causal LM and its tokenizer.

        Args:
            path: Model directory (or hub id) passed to ``from_pretrained``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Only the system prompt supplied in the request should be used; do
        # not let the chat template inject a default one.
        self.tokenizer.use_default_system_prompt = False

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a reply to a single user message.

        Args:
            data: Request payload. The following keys are popped from it:
                ``inputs`` (str, required): the user message.
                ``system_prompt`` (str, optional): system message placed
                    before the user message; omitted when missing or None.
                ``parameters`` (dict, optional): extra keyword arguments
                    forwarded to ``model.generate``; they override the
                    sampling defaults used below.

        Returns:
            A one-element list ``[{"generated_text": str}]`` holding the
            decoded continuation (the prompt tokens are not included), so the
            result can be serialized by the endpoint.

        Raises:
            KeyError: if ``inputs`` is missing from ``data``.
        """
        message = data.pop("inputs")
        system_prompt = data.pop("system_prompt", None)
        parameters = data.pop("parameters", None) or {}

        conversation = []
        if system_prompt is not None:
            conversation.append({"role": "system", "content": system_prompt})
        conversation.append({"role": "user", "content": message})
        logging.getLogger(__name__).info("conversation: %s", conversation)

        # add_generation_prompt=True makes templates that support it open an
        # assistant turn, so the model answers instead of continuing the user.
        input_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
        input_ids = input_ids.to(self.model.device)

        generate_kwargs = {
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
            "temperature": 0.6,
            "num_beams": 1,
            "repetition_penalty": 1.2,
        }
        # Caller-supplied parameters override the defaults above; the prompt
        # itself is always the one built here.
        generate_kwargs.update(parameters)
        generate_kwargs["input_ids"] = input_ids

        output_ids = self.model.generate(**generate_kwargs)
        # generate() returns prompt + continuation; keep only the new tokens
        # of the single sequence in the batch.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": text}]