File size: 9,142 Bytes
eeb7ca1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os

from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType

from stopping import get_stopping
from prompter import Prompter, PromptType


class H2OTextGenerationPipeline(TextGenerationPipeline):
    """
    Text-generation pipeline that optionally wraps the user text in an instruction prompt
    (via a Prompter), installs stopping criteria for that prompt type, and extracts the
    bot response from the generated text.
    """

    def __init__(self, *args, debug=False, chat=False, stream_output=False,
                 sanitize_bot_response=True,
                 use_prompter=True, prompter=None,
                 prompt_type=None, prompt_dict=None,
                 max_input_tokens=2048 - 256, **kwargs):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)
        :param args: positional args forwarded to TextGenerationPipeline
        :param debug: forwarded to Prompter when one is constructed here
        :param chat: forwarded to Prompter when one is constructed here
        :param stream_output: forwarded to Prompter when one is constructed here
        :param sanitize_bot_response: forwarded to prompter.get_response() in postprocess()
        :param use_prompter: Whether to use prompter.  If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param prompt_type: prompt_type, e.g. human_bot.  See prompt_type to model mapping in from prompter.py.
                            If use_prompter, then will make prompter and use it.
                            If None and a prompter is passed, the prompter's prompt_type is used.
        :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
        :param max_input_tokens: stored on the instance; not read elsewhere in this class
        :param kwargs: keyword args forwarded to TextGenerationPipeline
        """
        super().__init__(*args, **kwargs)
        # Final prompt of the most recent preprocess() call; read back in postprocess().
        self.prompt_text = None
        self.use_prompter = use_prompter
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
                if self.prompt_type is None:
                    # Keep stopping criteria and the token "fudge" in preprocess() consistent
                    # with the prompter actually in use, instead of operating on None.
                    self.prompt_type = self.prompter.prompt_type
                    if self.prompt_dict is None:
                        # NOTE(review): assumes Prompter may expose prompt_dict; falls back to None if not.
                        self.prompt_dict = getattr(self.prompter, 'prompt_dict', None)
            else:
                self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
                                         stream_output=stream_output)
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            # Any prompter passed in is deliberately ignored when use_prompter is False.
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = max_input_tokens  # not for generate, so ok that not kwargs

    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """
        Truncate an over-long prompt (keeping its tail), cap max_new_tokens to what still fits,
        apply the prompter template, then delegate to the parent preprocess().
        :param prompt_text: raw user text (used as the instruction when a prompter is set)
        :param prefix: forwarded to the parent preprocess()
        :param handle_long_generation: forwarded to the parent; defaults to 'hole' when None
        :param generate_kwargs: generation kwargs; 'max_new_tokens' is reduced in this local dict
                                when the prompt leaves less room than requested
        :return: whatever the parent preprocess() returns
        """
        # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
        # (None means unknown, in which case no truncation or capping is done)
        model_max_length = getattr(self.tokenizer, 'model_max_length', None)

        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
        if model_max_length is not None:
            num_prompt_tokens = None
            # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
            # For https://github.com/h2oai/h2ogpt/issues/192
            for _ in range(0, 3):
                num_prompt_tokens = len(self.tokenizer(prompt_text)['input_ids'])
                if num_prompt_tokens <= model_max_length:
                    if verbose:
                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
                    break
                ratio = len(prompt_text) / num_prompt_tokens
                # conservative by using int() when at least one char per token; below that,
                # int() would give 0 and the slice [-0:] would keep the whole text, so use
                # the fractional ratio instead
                chars_per_token = int(ratio) if ratio >= 1 else ratio
                # always keep at least one character so the slice really shortens the text
                keep_chars = max(1, int(model_max_length * chars_per_token))
                prompt_text = prompt_text[-keep_chars:]
                if verbose:
                    print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
                        num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
            else:
                # Every trial truncated, so the count predates the last cut; recount so the
                # max_new_tokens cap below is based on the text actually used.
                num_prompt_tokens = len(self.tokenizer(prompt_text)['input_ids'])

            # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
            #
            if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
                # then give room for prompt
                fudge = 20
            else:
                fudge = 0
            assert num_prompt_tokens is not None
            requested_new_tokens = generate_kwargs.get('max_new_tokens')
            # Only cap when the caller actually asked for max_new_tokens; otherwise leave
            # generation length to the model/generation config instead of raising KeyError.
            if requested_new_tokens is not None:
                max_new_tokens = max(0, min(requested_new_tokens,
                                            model_max_length - (num_prompt_tokens + fudge)))
                if max_new_tokens < requested_new_tokens:
                    if verbose:
                        print("Reduced max_new_tokens from %s -> %s" % (requested_new_tokens, max_new_tokens))
                    generate_kwargs['max_new_tokens'] = max_new_tokens

        data_point = dict(context='', instruction=prompt_text, input='')
        if self.prompter is not None:
            prompt_text = self.prompter.generate_prompt(data_point)
        # NOTE(review): stored per-instance, so postprocess() sees the prompt of the most
        # recent preprocess() call only.
        self.prompt_text = prompt_text
        if handle_long_generation is None:
            # forces truncation of inputs to avoid critical failure
            handle_long_generation = 'hole'
        return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                  **generate_kwargs)

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """
        Run the parent postprocess(), then reduce each record's 'generated_text' to the response.
        :param model_outputs: output of _forward()
        :param return_type: forwarded to the parent postprocess()
        :param clean_up_tokenization_spaces: forwarded to the parent postprocess()
        :return: list of records from the parent, with 'generated_text' replaced in place
        """
        records = super().postprocess(model_outputs, return_type=return_type,
                                      clean_up_tokenization_spaces=clean_up_tokenization_spaces)
        for rec in records:
            generated = rec['generated_text']
            if self.use_prompter:
                outputs = self.prompter.get_response(generated, prompt=self.prompt_text,
                                                     sanitize_bot_response=self.sanitize_bot_response)
            elif self.bot and self.human:
                # Text after the first bot marker, up to the next human marker.
                # Fall back to the full text when the bot marker is absent, rather than IndexError.
                pieces = generated.split(self.bot)
                answer = pieces[1] if len(pieces) > 1 else pieces[0]
                outputs = answer.strip().split(self.human)[0].strip()
            else:
                outputs = generated
            rec['generated_text'] = outputs
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """
        Add stopping criteria (when a prompter is in use) and run generation.
        :param model_inputs: output of preprocess()
        :param generate_kwargs: kwargs for model.generate()
        :return: dict with 'generated_sequence', 'input_ids' and 'prompt_text'
        """
        if self.can_stop:
            stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
                                             self.tokenizer, self.device,
                                             human=self.human, bot=self.bot)
            generate_kwargs['stopping_criteria'] = stopping_criteria
        # return super()._forward(model_inputs, **generate_kwargs)
        return self.__forward(model_inputs, **generate_kwargs)

    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
    def __forward(self, model_inputs, **generate_kwargs):
        """
        Call model.generate() and reshape its output to (input batch, sequences per input, ...).
        :param model_inputs: mapping with 'input_ids', 'prompt_text' and optionally 'attention_mask';
                             'prompt_text' is popped from it
        :param generate_kwargs: kwargs for model.generate(); 'prefix_length' is popped and, when > 0,
                                used to extend max_length / min_length
        :return: dict with 'generated_sequence', 'input_ids' and 'prompt_text'
        :raises ValueError: if the framework is "tf" but TensorFlow is not available
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        # generate_kwargs = copy.deepcopy(generate_kwargs)
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                    "generation_config" in generate_kwargs
                    and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            from transformers import is_tf_available
            if is_tf_available():
                import tensorflow as tf
                generated_sequence = tf.reshape(generated_sequence,
                                                (in_b, out_b // in_b, *generated_sequence.shape[1:]))
            else:
                raise ValueError("TF not available.")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}