Reizendretail commited on
Commit
b80fb09
1 Parent(s): c94c5f0

Delete handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -39
handler.py DELETED
@@ -1,39 +0,0 @@
1
- from typing import Dict, List, Any
2
- import logging
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
-
5
-
6
- class EndpointHandler():
7
- def __init__(self, path=""):
8
- self.model = AutoModelForCausalLM.from_pretrained(path)
9
- self.tokenizer = AutoTokenizer.from_pretrained(path)
10
- self.tokenizer.use_default_system_prompt = False
11
-
12
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
- """
14
- data args:
15
- inputs (:obj: `str`)
16
- date (:obj: `str`)
17
- Return:
18
- A :obj:`list` | `dict`: will be serialized and returned
19
- """
20
- # get inputs
21
- system_prompt = data.pop("system_prompt")
22
- message = data.pop("inputs")
23
- conversation = []
24
- conversation.append({"role": "system", "content": system_prompt})
25
- conversation.append({"role": "user", "content": message})
26
- logging.info(str(conversation))
27
- input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt")
28
- input_ids = input_ids.to(self.model.device)
29
-
30
- generate_kwargs = dict(
31
- {"input_ids": input_ids},
32
- do_sample=True,
33
- top_p=0.9,
34
- top_k=50,
35
- temperature=0.6,
36
- num_beams=1,
37
- repetition_penalty=1.2,
38
- )
39
- return self.model.generate(**generate_kwargs)