nizar-sayad committed on
Commit
1d29082
1 Parent(s): bcda569

custom handler

Files changed (2)
  1. handler.py +59 -0
  2. requirements.txt +3 -0
handler.py ADDED
@@ -0,0 +1,59 @@
+ from typing import Dict, List, Any
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+ import torch
+ from accelerate import Accelerator
+ import bitsandbytes as bnb
+
+ accelerator = Accelerator()
+
+ # Create a stopping criteria class
+ class KeywordsStoppingCriteria(StoppingCriteria):
+     def __init__(self, keywords_ids: list, occurrences: int):
+         super().__init__()
+         self.keywords = keywords_ids
+         self.occurrences = occurrences
+         self.count = 0
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         if input_ids[0][-1] in self.keywords:
+             self.count += 1
+             if self.count == self.occurrences:
+                 return True
+         return False
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # load model and processor from path
+         self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", load_in_8bit=True)
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
+         """
+         Args:
+             data (:dict:):
+                 The payload with the text prompt.
+         """
+         # process input
+         input = data.pop("input", data)
+
+         stop_words = ['.']
+         stop_ids = [self.tokenizer.encode(w)[1] for w in stop_words]
+         gen_outputs = []
+         gen_outputs_no_input = []
+         gen_input = self.tokenizer(input, return_tensors="pt")
+         for _ in range(5):
+             stop_criteria = KeywordsStoppingCriteria(stop_ids, occurrences=2)
+             gen_output = self.model.generate(gen_input.input_ids, do_sample=True,
+                                              top_k=10,
+                                              top_p=0.95,
+                                              max_new_tokens=100,
+                                              penalty_alpha=0.6,
+                                              stopping_criteria=StoppingCriteriaList([stop_criteria])
+                                              )
+             gen_outputs.append(gen_output)
+             gen_outputs_no_input.append(gen_output[0][len(gen_input.input_ids[0]):])
+
+         gen_outputs_decoded = [self.tokenizer.decode(gen_output[0], skip_special_tokens=True) for gen_output in gen_outputs]
+         gen_outputs_no_input_decoded = [self.tokenizer.decode(gen_output_no_input, skip_special_tokens=True) for gen_output_no_input in gen_outputs_no_input]
+
+         return {"gen_outputs_decoded": gen_outputs_decoded, "gen_outputs_no_input_decoded": gen_outputs_no_input_decoded}
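
A custom handler like this can be smoke-tested locally before the endpoint is deployed. The sketch below is not part of the commit: the path and prompt are placeholders, and it simply exercises what EndpointHandler.__call__ already does (pops the "input" key and returns two lists of five decoded generations).

# Local smoke test for the custom handler (illustrative sketch, not in the commit).
# Assumes handler.py is importable from the working directory and that a GPU with
# enough memory is available, since __init__ loads the model with load_in_8bit=True.
from handler import EndpointHandler

# "." is a placeholder: point it at a local clone of the model repository
# (or a Hub model id) so from_pretrained can find the weights and tokenizer.
handler = EndpointHandler(path=".")

# __call__ pops the "input" key from the payload.
result = handler({"input": "Once upon a time"})

print(result["gen_outputs_decoded"])           # prompt + continuation, 5 samples
print(result["gen_outputs_no_input_decoded"])  # continuation only, 5 samples
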
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ accelerate
+ bitsandbytes
+ transformers
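
torch is imported by handler.py but not listed here, presumably because the endpoint's base image already provides it; the pinned packages cover device_map="auto" and 8-bit loading. Once deployed, the endpoint expects a JSON payload carrying the "input" key that the handler pops. A minimal client sketch, where the URL and token are placeholders:

import requests

# Placeholders: substitute the deployed Inference Endpoint URL and a valid access token.
API_URL = "https://<endpoint-name>.endpoints.huggingface.cloud"
HEADERS = {"Authorization": "Bearer <hf_token>", "Content-Type": "application/json"}

# The handler's __call__ pops "input" from the payload, so that is the only required key.
response = requests.post(API_URL, headers=HEADERS, json={"input": "Once upon a time"})
print(response.json()["gen_outputs_no_input_decoded"])
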