sarang-shrivastava committed
Commit bedc493 · 1 Parent(s): b313bf8

Initial handler file

Files changed (1)
  1. handler.py +120 -0
handler.py ADDED
@@ -0,0 +1,120 @@
+from typing import Dict, List, Any
+
+import torch
+import transformers
+from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
+
+# NOTE: "" is a placeholder; it must be set to the tokenizer's model repo id
+# before this file will run.
+tokenizer = AutoTokenizer.from_pretrained(
+    "",
+    trust_remote_code=True,
+)
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token = tokenizer.eos_token
+tokenizer.padding_side = 'left'
+
+stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
+
+
+# Custom stopping criterion: stop as soon as the most recent token
+# matches one of the configured stop token ids.
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stop_id in stop_token_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
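+# generate() evaluates each stopping criterion after every decoding step, so
+# checking only the last token id is enough to halt right after a stop token
+# is produced. The check reads input_ids[0], i.e. it assumes a batch size of
+# 1, which matches how __call__ below builds its inputs.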
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        self.torch_dtype = torch.bfloat16
+        # self.torch_dtype = torch.float32
+
+        self.tokenizer = tokenizer
+
+        self.config = transformers.AutoConfig.from_pretrained(
+            path,
+            trust_remote_code=True,
+        )
+
+        # self.config.attn_config['attn_impl'] = 'triton'
+        # self.config.update({"max_seq_len": 4096})
+
+        self.model = transformers.AutoModelForCausalLM.from_pretrained(
+            path,
+            config=self.config,
+            torch_dtype=self.torch_dtype,
+            trust_remote_code=True,
+        )
+
+        # Run in inference mode, on GPU when one is available.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model.eval()
+        self.model.to(device=device, dtype=self.torch_dtype)
+
+        self.generate_kwargs = {
+            'max_new_tokens': 512,
+            # Near-zero temperature with do_sample=True is effectively
+            # greedy decoding via the sampling code path.
+            'temperature': 0.0001,
+            'top_p': 1.0,
+            'top_k': 0,
+            'use_cache': True,
+            'do_sample': True,
+            'eos_token_id': self.tokenizer.eos_token_id,
+            'pad_token_id': self.tokenizer.pad_token_id,
+            'repetition_penalty': 1.1,
+        }
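+        # Note: __call__ mutates this dict in place, so generation parameters
+        # sent with one request become the defaults for subsequent requests.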
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        data args:
+            inputs (:obj:`str`): the prompt to generate from
+            kwargs: optional generation parameters (see below)
+        Return:
+            A :obj:`list` | `dict`: will be serialized and returned
+        """
+
+        # streamer = TextIteratorStreamer(
+        #     self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        # )
+        stop = StopOnTokens()
+
+        ## Model parameters: values in the request payload override the defaults.
+        for key in ('max_new_tokens', 'temperature', 'top_p', 'top_k',
+                    'do_sample', 'repetition_penalty'):
+            if key in data:
+                self.generate_kwargs[key] = data[key]
+
+        ## Add the streamer and stopping criteria
+        # self.generate_kwargs['streamer'] = streamer
+        self.generate_kwargs['stopping_criteria'] = StoppingCriteriaList([stop])
+
+        ## Prepare the inputs
+        inputs = data.pop("inputs", data)
+        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
+        input_ids = input_ids.to(self.model.device)
+
+        # encoded_inp = self.tokenizer(inputs, return_tensors='pt', padding=True)
+        # for key, value in encoded_inp.items():
+        #     encoded_inp[key] = value.to('cuda:0')
+
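+        # generate() is called below without an explicit attention_mask; for a
+        # single unpadded prompt this is safe, though Transformers may log a
+        # warning when pad_token_id equals eos_token_id.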
+        ## Invoke the model
+        # with torch.no_grad():
+        #     gen_tokens = self.model.generate(
+        #         input_ids=encoded_inp['input_ids'],
+        #         attention_mask=encoded_inp['attention_mask'],
+        #         **generate_kwargs,
+        #     )
+
+        # ## Decode using tokenizer
+        # decoded_gen = self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
+
+        with torch.no_grad():
+            output_ids = self.model.generate(input_ids, **self.generate_kwargs)
+        # Slice output_ids to keep only the newly generated tokens.
+        new_tokens = output_ids[0, len(input_ids[0]):]
+        output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+        return [{"gen_text": output_text, "input_text": inputs}]
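
A minimal sketch of exercising the handler locally, assuming the blank tokenizer repo id above has been filled in and that "./model" is a hypothetical local directory holding the model weights; the payload keys mirror the ones __call__ reads:

    from handler import EndpointHandler

    # "./model" is a placeholder path to a local copy of the model repository.
    handler = EndpointHandler(path="./model")

    payload = {
        "inputs": "Write a haiku about GPUs.",
        "max_new_tokens": 64,
        "temperature": 0.7,
    }
    result = handler(payload)
    print(result[0]["gen_text"])

The handler returns a single-element list, so result[0]["gen_text"] holds the completion and result[0]["input_text"] echoes the prompt.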