Reizendretail committed on
Commit 2d8ab64
1 Parent(s): 8f2521b

Delete handler.py

Files changed (1)
  1. handler.py +0 -40
handler.py DELETED
@@ -1,40 +0,0 @@
- from typing import Dict, List, Any
- import logging
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
- class EndpointHandler():
-     def __init__(self, path=""):
-         self.model = AutoModelForCausalLM.from_pretrained(path,device_map="cuda:0", load_in_4bit=True)
-         self.tokenizer = AutoTokenizer.from_pretrained(path)
-         self.tokenizer.use_default_system_prompt = False
-
-     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-         """
-         data args:
-             inputs (:obj: `str`)
-             date (:obj: `str`)
-         Return:
-             A :obj:`list` | `dict`: will be serialized and returned
-         """
-         # get inputs
-         system_prompt = data.pop("system_prompt")
-         message = data.pop("inputs")
-         conversation = []
-         conversation.append({"role": "system", "content": system_prompt})
-         conversation.append({"role": "user", "content": message})
-         raise KeyError
-         logging.info(str(conversation))
-         input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt")
-         input_ids = input_ids.to(self.model.device)
-
-         generate_kwargs = dict(
-             {"input_ids": input_ids},
-             do_sample=True,
-             top_p=0.9,
-             top_k=50,
-             temperature=0.6,
-             num_beams=1,
-             repetition_penalty=1.2,
-         )
-         return self.model.generate(**generate_kwargs)
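
For reference only (not part of this commit): the deleted handler contained a stray `raise KeyError` that made everything after it unreachable, required `system_prompt` to be present in the request, and returned raw token IDs from `generate` instead of decoded text. The sketch below is a minimal, hedged rewrite of such a custom Inference Endpoints handler, not the repository's code. The default system prompt, `device_map="auto"`, `max_new_tokens=256`, and the `generated_text` response key are assumptions for illustration; the sampling parameters are carried over from the deleted file.

# A minimal sketch of a working EndpointHandler, assuming the standard
# Hugging Face Inference Endpoints custom-handler interface.
from typing import Dict, List, Any
import logging

from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        # device_map="auto" and 4-bit loading are assumptions; adjust to the
        # endpoint's actual hardware and transformers version.
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (str): the user message
            system_prompt (str, optional): system message for the chat template
        Return:
            A list with one dict holding the generated text.
        """
        # Fall back to a generic system prompt instead of raising KeyError.
        system_prompt = data.pop("system_prompt", "You are a helpful assistant.")
        message = data.pop("inputs")

        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ]
        logging.info(str(conversation))

        # Build the prompt with the model's chat template and move it to the
        # model's device.
        input_ids = self.tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        output_ids = self.model.generate(
            input_ids,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.6,
            repetition_penalty=1.2,
        )

        # Decode only the newly generated tokens, not the prompt.
        generated = output_ids[0, input_ids.shape[-1]:]
        text = self.tokenizer.decode(generated, skip_special_tokens=True)
        return [{"generated_text": text}]

With a handler like this, a request body such as {"inputs": "...", "system_prompt": "..."} would yield a serializable response of the form [{"generated_text": "..."}].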