from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MAX_TOKENS_IN_BATCH = 4_000  # Hard limit to prevent OOMs
DEFAULT_MAX_NEW_TOKENS = 10  # By default limit the output to 10 tokens


class EndpointHandler:
    """
    This class is used to handle the inference with pre and post process for
    text2text models. See
    https://huggingface.co/docs/inference-endpoints/guides/custom_handler for
    more details.
    """

    def __init__(self, path: str = ""):
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
        except Exception:
            # Loading with device_map="auto" depends on accelerate; log its
            # version before re-raising to ease debugging.
            import accelerate

            print(f"ACCELERATE VERSION: {accelerate.__version__}")
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        This method is called when the endpoint is called.

        Arguments
        ---------
            data (Dict[str, Any]):
                Must contain the input data under the `inputs` key and any
                parameters for the inference under `parameters`.

        Returns
        -------
            output (List[Dict[str, Any]]):
                A list with one item per input sequence, where each item is a
                dictionary containing `generated_text` (the decoded output),
                `perplexity` and, if requested, `first_token_probs`.
        """
        input_texts = data["inputs"]
        generate_kwargs = data.get("parameters", {})
        # This is not technically a generate_kwarg, but needs to live under parameters
        check_first_tokens = generate_kwargs.pop("check_first_tokens", None)
        max_new_tokens = (
            generate_kwargs.pop("max_new_tokens", None) or DEFAULT_MAX_NEW_TOKENS
        )

        # Tokenizing input texts
        inputs = self.tokenizer(
            input_texts, return_tensors="pt", padding=True, truncation=True,
        )["input_ids"]

        # Make sure not to OOM if too many inputs
        assert inputs.dim() == 2, f"Inputs have dimension {inputs.dim()} != 2"
        total_tokens = inputs.shape[0] * (inputs.shape[1] + max_new_tokens - 1)
        assert (
            total_tokens <= MAX_TOKENS_IN_BATCH
        ), f"Passed {total_tokens} tokens (shape: {inputs.shape}, max_new_tokens: {max_new_tokens}), which exceeds the limit of {MAX_TOKENS_IN_BATCH}"

        # Run inference on the model's device (GPU when available)
        inputs = inputs.to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=max_new_tokens,
                **generate_kwargs,
            )
        inputs = inputs.to("cpu")
        scores = [s.to("cpu") for s in outputs.scores]
        del outputs

        # process outputs
        to_return: Dict[str, Any] = {
            "generated_text": self._output_text_from_scores(scores),
            "perplexity": [float(p) for p in self._perplexity(scores)],
        }
        if check_first_tokens:
            to_return["first_token_probs"] = self._get_first_token_probs(
                check_first_tokens, scores
            )

        # Reformat output to conform to HF Pipeline format
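        # e.g. {"generated_text": ["a", "b"], "perplexity": [0.9, 0.7]} ->
        #   [{"generated_text": "a", "perplexity": 0.9},
        #    {"generated_text": "b", "perplexity": 0.7}]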
        return [
            {key: to_return[key][ndx] for key in to_return.keys()}
            for ndx in range(len(to_return["generated_text"]))
        ]

    def _output_text_from_scores(self, scores: List[torch.Tensor]) -> List[str]:
        """
        Returns the decoded text from the scores.
        TODO (ENG-20823): Use the returned sequences so we pay attention to
        things like bad_words, force_words etc.
        """
        # For each sequence in the batch, take the argmax token id at each
        # generation step
        batch_token_ids = [
            [int(score[ndx].argmax()) for score in scores]
            for ndx in range(scores[0].shape[0])
        ]
        # Fix for new tokens being generated after EOS
        new_batch_token_ids = []
        for token_ids in batch_token_ids:
            try:
                new_token_ids = token_ids[
                    : token_ids.index(self.tokenizer.eos_token_id)
                ]
            except ValueError:
                # No EOS in the output; keep everything except the final token
                new_token_ids = token_ids[:-1]

            new_batch_token_ids.append(new_token_ids)
        return self.tokenizer.batch_decode(new_batch_token_ids)

    def _perplexity(self, scores: List[torch.Tensor]) -> List[float]:
        """
        Returns a perplexity-style confidence for each outputted sequence:
            e^( sum(ln(p(token))) / N )
        i.e. the geometric mean of the per-step probabilities of the chosen
        (argmax) tokens; higher means the model was more confident.

        TODO (ENG-20823): don't include the trailing pad tokens in perplexity
        """

        return torch.exp(
            torch.stack(
                [score.softmax(dim=1).log().max(dim=1)[0] for score in scores]
            ).sum(dim=0)
            / len(scores)
        ).tolist()

    def _get_first_token_probs(
        self, tokens: List[str], scores: List[torch.Tensor]
    ) -> List[Dict[str, float]]:
        """
        Return the softmaxed probabilities of the specific tokens for each
        output
        """
        first_token_probs = []
        # scores[0] holds the logits for the first generated position
        softmaxed_scores = scores[0].softmax(dim=1)

        # Finding the correct token IDs
        # TODO (ENG-20824): Support multi-token words
        token_ids = {}
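        # For example (hypothetical T5-style ids): encode("Yes") -> [2163, 1]
        # (token id plus trailing EOS), while a word the tokenizer splits into
        # several pieces gets marked with -1 below.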
        for token in tokens:
            encoded_token: List[int] = self.tokenizer.encode(token)
            if len(encoded_token) > 2:
                # The tokenizer broke the token into multiple pieces (the
                # encoding is the piece ids plus a trailing EOS), so it has
                # no single-step probability
                token_ids[token] = -1
            else:
                token_ids[token] = encoded_token[0]

        # Now finding the scores for each token in the list
        for seq_ndx in range(scores[0].shape[0]):
            curr_token_probs: Dict[str, float] = {}

            for token in tokens:
                if token_ids[token] == -1:
                    # Multi-piece token: report zero probability
                    curr_token_probs[token] = 0.0
                else:
                    curr_token_probs[token] = float(
                        softmaxed_scores[seq_ndx, token_ids[token]]
                    )

            first_token_probs.append(curr_token_probs)

        return first_token_probs
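

if __name__ == "__main__":
    # Minimal local smoke test. "google/flan-t5-small" is an illustrative
    # text2text checkpoint, not necessarily the model this handler ships with.
    handler = EndpointHandler("google/flan-t5-small")
    print(
        handler(
            {
                "inputs": ["Is the sky blue?"],
                "parameters": {"check_first_tokens": ["Yes", "No"]},
            }
        )
    )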