bradleyfowler123 committed on
Commit 1061bb6
1 Parent(s): 4260712

Upload handler.py

Files changed (1)
  1. handler.py +169 -0
handler.py ADDED
@@ -0,0 +1,169 @@
+ from typing import Any, Dict, List
+
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ MAX_TOKENS_IN_BATCH = 4_000  # Hard limit to prevent OOMs
+ DEFAULT_MAX_NEW_TOKENS = 10  # By default limit the output to 10 tokens
+
+
+ class EndpointHandler:
+     """
+     This class handles inference, with pre- and post-processing, for
+     text2text models. See
+     https://huggingface.co/docs/inference-endpoints/guides/custom_handler for
+     more details.
+     """
+
+     def __init__(self, path: str = ""):
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(path)
+             self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
+         except:
+             # Log the accelerate version to help debug device_map failures, then re-raise
+             import accelerate
+
+             print(f"ACCELERATE VERSION: {accelerate.__version__}")
+             raise
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         This method is called when the endpoint is called.
+
+         Arguments
+         ---------
+         data (Dict[str, Any]):
+             Must contain the input data under the `inputs` key and any
+             parameters for the inference under `parameters`.
+
+         Returns
+         -------
+         output (List[Dict[str, Any]]):
+             A list with one item per input text, where each item is a
+             dictionary containing `generated_text`, `perplexity` and,
+             if `check_first_tokens` was passed, `first_token_probs`.
+         """
+         input_texts = data["inputs"]
+         generate_kwargs = data.get("parameters", {})
+         # This is not technically a generate_kwarg, but needs to live under parameters
+         check_first_tokens = generate_kwargs.pop("check_first_tokens", None)
+         max_new_tokens = (
+             generate_kwargs.pop("max_new_tokens", None) or DEFAULT_MAX_NEW_TOKENS
+         )
+
+         # Tokenizing input texts
+         inputs = self.tokenizer(
+             input_texts, return_tensors="pt", padding=True, truncation=True,
+         )["input_ids"]
+
+         # Make sure not to OOM if too many inputs
+         assert inputs.dim() == 2, f"Inputs have dimension {inputs.dim()} != 2"
+         total_tokens = inputs.shape[0] * (inputs.shape[1] + max_new_tokens - 1)
+         assert (
+             total_tokens <= MAX_TOKENS_IN_BATCH
+         ), f"Passed {total_tokens} (shape: {inputs.shape}, max_new_tokens: {max_new_tokens}), which is greater than limit of {MAX_TOKENS_IN_BATCH}"
+
+         # Run inference on GPU
+         inputs = inputs.to("cuda:0")
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 inputs,
+                 output_scores=True,
+                 return_dict_in_generate=True,
+                 max_new_tokens=max_new_tokens,
+                 **generate_kwargs,
+             )
+         inputs = inputs.to("cpu")
+         scores = [s.to("cpu") for s in outputs.scores]
+         del outputs
+
+         # process outputs
+         to_return: Dict[str, Any] = {
+             "generated_text": self._output_text_from_scores(scores),
+             "perplexity": [float(p) for p in self._perplexity(scores)],
+         }
+         if check_first_tokens:
+             to_return["first_token_probs"] = self._get_first_token_probs(
+                 check_first_tokens, scores
+             )
+
+         # Reformat output to conform to HF Pipeline format
+         return [
+             {key: to_return[key][ndx] for key in to_return.keys()}
+             for ndx in range(len(to_return["generated_text"]))
+         ]
+
+     def _output_text_from_scores(self, scores: List[torch.Tensor]) -> List[str]:
+         """
+         Returns the decoded text from the scores.
+         TODO (ENG-20823): Use the returned sequences so we pay attention to
+         things like bad_words, force_words etc.
+         """
+         # Always return list format
+         batch_token_ids = [
+             [score[ndx].argmax() for score in scores]
+             for ndx in range(scores[0].shape[0])
+         ]
+         # Fix for new tokens being generated after EOS
+         new_batch_token_ids = []
+         for token_ids in batch_token_ids:
+             try:
+                 new_token_ids = token_ids[
+                     : token_ids.index(self.tokenizer.eos_token_id)
+                 ]
+             except ValueError:
+                 new_token_ids = token_ids[:-1]
+
+             new_batch_token_ids.append(new_token_ids)
+         return self.tokenizer.batch_decode(new_batch_token_ids)
+
+     def _perplexity(self, scores: List[torch.Tensor]) -> List[float]:
+         """
+         Returns the perplexity (model confidence) of the outputted text:
+         e^( sum(ln(p(word))) / N )
+
+         TODO (ENG-20823): don't include the trailing pad tokens in perplexity
+         """
+
+         return torch.exp(
+             torch.stack(
+                 [score.softmax(axis=1).log().max(axis=1)[0] for score in scores]
+             ).sum(axis=0)
+             / len(scores)
+         ).tolist()
+
+     def _get_first_token_probs(
+         self, tokens: List[str], scores: List[torch.Tensor]
+     ) -> List[Dict[str, float]]:
+         """
+         Return the softmaxed probabilities of the specific tokens for each
+         output
+         """
+         first_token_probs = []
+         softmaxed_scores = scores[0].softmax(axis=1)
+
+         # Finding the correct token IDs
+         # TODO (ENG-20824): Support multi-token words
+         token_ids = {}
+         for token in tokens:
+             encoded_token: List[int] = self.tokenizer.encode(token)
+             if len(encoded_token) > 2:
+                 # This means the tokenizer broke the token up into multiple parts
+                 token_ids[token] = -1
+             else:
+                 token_ids[token] = encoded_token[0]
+
+         # Now finding the scores for each token in the list
+         for seq_ndx in range(scores[0].shape[0]):
+             curr_token_probs: Dict[str, float] = {}
+
+             for token in tokens:
+                 if token_ids[token] == -1:
+                     curr_token_probs[token] = 0
+                 else:
+                     curr_token_probs[token] = float(
+                         softmaxed_scores[seq_ndx, token_ids[token]]
+                     )
+
+             first_token_probs.append(curr_token_probs)
+
+         return first_token_probs
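
A minimal local smoke test of this handler might look like the sketch below; it is illustrative only. The model directory `./model`, the example input, and the candidate tokens are assumptions, while the payload keys (`inputs`, `parameters`, `max_new_tokens`, `check_first_tokens`) follow the `__call__` docstring above. The handler moves tensors to `cuda:0`, so a GPU is assumed.

# Hypothetical local test of EndpointHandler -- path, input and tokens are made up
from handler import EndpointHandler

handler = EndpointHandler(path="./model")  # assumed directory holding a seq2seq model

payload = {
    "inputs": ["Is this review positive? Review: great product"],
    "parameters": {
        "max_new_tokens": 5,
        "check_first_tokens": ["yes", "no"],  # optional: probabilities for the first generated token
    },
}

for item in handler(payload):
    # Each item holds `generated_text`, `perplexity` and (here) `first_token_probs`
    print(item["generated_text"], item["perplexity"], item["first_token_probs"])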
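
As the `_perplexity` docstring notes, the returned value is e^(sum(ln p) / N) over the highest-probability token at each generation step, i.e. the geometric mean of those probabilities, so it lies in (0, 1] and behaves as a confidence score rather than a conventional perplexity. A tiny sketch of that arithmetic on made-up probabilities:

import math

# Assumed per-step probabilities of the chosen tokens (illustrative only)
token_probs = [0.9, 0.8, 0.95]

confidence = math.exp(sum(math.log(p) for p in token_probs) / len(token_probs))
print(round(confidence, 2))  # 0.88 -- the geometric mean of the three probabilities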