# File size: 10,541 Bytes
# b7bd9ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os

from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType

from stopping import get_stopping
from prompter import Prompter, PromptType


class H2OTextGenerationPipeline(TextGenerationPipeline):
    """
    Text-generation pipeline that wraps the raw instruction in a prompt template
    (via Prompter), limits the input length, installs stopping criteria and
    strips the prompt from the generated text again.
    """

    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=False,
                 use_prompter=True, prompter=None,
                 context='', iinput='',
                 prompt_type=None, prompt_dict=None,
                 max_input_tokens=2048 - 256, **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)
        :param args: positional arguments forwarded to TextGenerationPipeline
        :param debug: passed to Prompter when one is created here; also enables printing of
                      prompt and output in postprocess()
        :param chat: passed to Prompter when one is created here
        :param stream_output: passed to Prompter when one is created here
        :param sanitize_bot_response: passed to prompter.get_response() in postprocess()
        :param use_prompter: Whether to use prompter.  If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param context: value used for the 'context' field of the data point given to the prompter
        :param iinput: value used for the 'input' field of the data point given to the prompter
        :param prompt_type: prompt_type, e.g. human_bot.  See prompt_type to model mapping in from prompter.py.
                            If use_prompter, then will make prompter and use it.
        :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
        :param max_input_tokens: upper bound on the number of tokens of the raw instruction,
                                 applied in preprocess() (None disables this extra bound)
        :param kwargs: keyword arguments forwarded to TextGenerationPipeline
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None
        self.debug = debug
        self.use_prompter = use_prompter
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        self.context = context
        self.iinput = iinput
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None, "passed prompter must have a prompt_type"
            else:
                self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
                                         stream_output=stream_output)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            # without prompting there are no human/bot markers to stop on
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs

    @staticmethod
    def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
        """
        Truncate prompt_text from the front (keeping the tail) so that it fits the token limit.

        The limit is tokenizer.model_max_length, further reduced to max_prompt_length when given.
        If the tokenizer has no usable model_max_length, nothing is truncated.

        :param prompt_text: text to limit
        :param tokenizer: callable returning a mapping with 'input_ids' for a text
        :param max_prompt_length: optional extra upper bound on the number of tokens
        :return: (possibly truncated prompt_text, number of tokens of the returned text,
                  or None when the limit is unknown)
        """
        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))

        # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
        model_max_length = getattr(tokenizer, 'model_max_length', None)
        if model_max_length is None:
            # unknown limit, so cannot truncate
            return prompt_text, None
        if max_prompt_length is not None:
            model_max_length = min(model_max_length, max_prompt_length)

        # cut at some upper likely limit to avoid excessive tokenization etc
        # upper bound of 10 chars/token, e.g. special chars sometimes are long
        max_chars = model_max_length * 10
        if len(prompt_text) > max_chars:
            len0 = len(prompt_text)
            prompt_text = prompt_text[-max_chars:]
            if verbose:
                print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)

        # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
        # For https://github.com/h2oai/h2ogpt/issues/192
        num_prompt_tokens = None
        for _ in range(3):
            num_prompt_tokens = len(tokenizer(prompt_text)['input_ids'])
            if num_prompt_tokens <= model_max_length:
                if verbose:
                    print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
                break
            # conservative by rounding down, but never 0: a slice of [-0:] would keep everything
            chars_per_token = max(1, len(prompt_text) // num_prompt_tokens)
            # keep tail, where question is if using langchain
            prompt_text = prompt_text[-model_max_length * chars_per_token:]
            if verbose:
                print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
                    num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
        else:
            # every trial truncated, so the last count refers to the text before the final cut
            num_prompt_tokens = len(tokenizer(prompt_text)['input_ids'])

        # max_new_tokens is intentionally not reduced here:
        # just rely upon stopping to reach limit of model
        return prompt_text, num_prompt_tokens

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """
        Limit the raw instruction, wrap it in the prompt template and defer to the base preprocess.

        :param prompt_text: raw instruction text
        :param prefix: forwarded to the base class
        :param handle_long_generation: forwarded to the base class unchanged
        :param generate_kwargs: forwarded to the base class
        :return: result of TextGenerationPipeline.preprocess() for the final prompt
        """
        # honor max_input_tokens in addition to the tokenizer's own limit
        prompt_text, _ = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer,
                                                                max_prompt_length=self.max_input_tokens)

        data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
        if self.prompter is not None:
            prompt_text = self.prompter.generate_prompt(data_point)
        # remembered so that postprocess() can strip the prompt from the generated text
        self.prompt_text = prompt_text
        # handle_long_generation is not forced to any truncation mode here,
        # since the prompt was already limited above
        return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                  **generate_kwargs)

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """
        Run the base postprocess and reduce each record's 'generated_text' to the response.

        :param model_outputs: forwarded to the base class
        :param return_type: forwarded to the base class
        :param clean_up_tokenization_spaces: forwarded to the base class
        :return: records of the base class with 'generated_text' replaced by the extracted response
        """
        records = super().postprocess(model_outputs, return_type=return_type,
                                      clean_up_tokenization_spaces=clean_up_tokenization_spaces)
        for rec in records:
            outputs = rec['generated_text']
            if self.use_prompter:
                outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.human:
                # text after the first bot marker, up to the next bot or human marker;
                # if there is no bot marker, fall back to the full text instead of raising IndexError
                pieces = outputs.split(self.bot)
                if len(pieces) > 1:
                    outputs = pieces[1]
                outputs = outputs.split(self.human)[0]
            rec['generated_text'] = outputs
            if self.debug:
                # only on request, to avoid dumping every prompt and response to stdout
                print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True)
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """
        Add stopping criteria (when prompting is in use) and run generation.

        :param model_inputs: output of preprocess()
        :param generate_kwargs: keyword arguments for model.generate()
        :return: dict with 'generated_sequence', 'input_ids' and 'prompt_text'
        """
        if self.can_stop:
            stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
                                             self.tokenizer, self.device,
                                             human=self.human, bot=self.bot,
                                             model_max_length=self.tokenizer.model_max_length)
            generate_kwargs['stopping_criteria'] = stopping_criteria
        # use own copy of the base _forward instead of super()._forward(), see __forward
        return self.__forward(model_inputs, **generate_kwargs)

    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
    def __forward(self, model_inputs, **generate_kwargs):
        """
        Call model.generate() on the prepared inputs and reshape the result per input row.

        :param model_inputs: mapping with 'input_ids', 'prompt_text' and optionally 'attention_mask';
                             'prompt_text' is popped from it
        :param generate_kwargs: keyword arguments for model.generate(); 'prefix_length' is consumed here
        :return: dict with 'generated_sequence', 'input_ids' and 'prompt_text'
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length.
        # generate_kwargs is not deep-copied here (that is the difference to the original _forward);
        # only top-level keys of this call's own kwargs dict are modified.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                # max_length counts the prefix too, so extend it by the prefix length
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        # split the leading dimension into (inputs, sequences per input)
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not avaialble.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}