TheBloke commited on
Commit
99e0f85
·
1 Parent(s): db3cc40

Initial GPTQ model commit

Browse files
Files changed (1) hide show
  1. h2oai_pipeline.py +929 -0
h2oai_pipeline.py ADDED
@@ -0,0 +1,929 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from transformers import TextGenerationPipeline
4
+ from transformers.pipelines.text_generation import ReturnType
5
+
6
+
7
+
8
+
9
+
10
+ class H2OTextGenerationPipeline(TextGenerationPipeline):
11
+ def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
+ sanitize_bot_response=False,
13
+ use_prompter=True, prompter=None,
14
+ prompt_type=None, prompt_dict=None,
15
+ max_input_tokens=2048 - 256, **kwargs):
16
+ """
17
+ HF-like pipeline, but handle instruction prompting and stopping (for some models)
18
+ :param args:
19
+ :param debug:
20
+ :param chat:
21
+ :param stream_output:
22
+ :param sanitize_bot_response:
23
+ :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
24
+ :param prompter: prompter, can pass if have already
25
+ :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in
26
+ If use_prompter, then will make prompter and use it.
27
+ :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
28
+ :param max_input_tokens:
29
+ :param kwargs:
30
+ """
31
+ super().__init__(*args, **kwargs)
32
+ self.prompt_text = None
33
+ self.use_prompter = use_prompter
34
+ self.prompt_type = prompt_type
35
+ self.prompt_dict = prompt_dict
36
+ self.prompter = prompter
37
+ if self.use_prompter:
38
+ if self.prompter is not None:
39
+ assert self.prompter.prompt_type is not None
40
+ else:
41
+ self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
42
+ stream_output=stream_output)
43
+ self.human = self.prompter.humanstr
44
+ self.bot = self.prompter.botstr
45
+ self.can_stop = True
46
+ else:
47
+ self.prompter = None
48
+ self.human = None
49
+ self.bot = None
50
+ self.can_stop = False
51
+ self.sanitize_bot_response = sanitize_bot_response
52
+ self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
53
+
54
+ @staticmethod
55
+ def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
56
+ verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
57
+
58
+ if hasattr(tokenizer, 'model_max_length'):
59
+ # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
60
+ model_max_length = tokenizer.model_max_length
61
+ if max_prompt_length is not None:
62
+ model_max_length = min(model_max_length, max_prompt_length)
63
+ # cut at some upper likely limit to avoid excessive tokenization etc
64
+ # upper bound of 10 chars/token, e.g. special chars sometimes are long
65
+ if len(prompt_text) > model_max_length * 10:
66
+ len0 = len(prompt_text)
67
+ prompt_text = prompt_text[-model_max_length * 10:]
68
+ if verbose:
69
+ print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
70
+ else:
71
+ # unknown
72
+ model_max_length = None
73
+
74
+ if model_max_length is not None:
75
+ num_prompt_tokens = None
76
+ # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
77
+ # For https://github.com/h2oai/h2ogpt/issues/192
78
+ for trial in range(0, 3):
79
+ prompt_tokens = tokenizer(prompt_text)['input_ids']
80
+ num_prompt_tokens = len(prompt_tokens)
81
+ if num_prompt_tokens > model_max_length:
82
+ # conservative by using int()
83
+ chars_per_token = int(len(prompt_text) / num_prompt_tokens)
84
+ # keep tail, where question is if using langchain
85
+ prompt_text = prompt_text[-model_max_length * chars_per_token:]
86
+ if verbose:
87
+ print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
88
+ num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
89
+ else:
90
+ if verbose:
91
+ print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
92
+ break
93
+
94
+ # Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
95
+ if False:
96
+ # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
97
+ #
98
+ assert num_prompt_tokens is not None
99
+ if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
100
+ # then give room for prompt
101
+ fudge = 20
102
+ else:
103
+ fudge = 0
104
+ max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
105
+ model_max_length - (num_prompt_tokens + fudge)))
106
+ if max_new_tokens < generate_kwargs['max_new_tokens']:
107
+ if verbose:
108
+ print("Reduced max_new_tokens from %s -> %s" % (
109
+ generate_kwargs['max_new_tokens'], max_new_tokens))
110
+ generate_kwargs['max_new_tokens'] = max_new_tokens
111
+ return prompt_text
112
+
113
+ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
114
+ prompt_text = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
115
+
116
+ data_point = dict(context='', instruction=prompt_text, input='')
117
+ if self.prompter is not None:
118
+ prompt_text = self.prompter.generate_prompt(data_point)
119
+ self.prompt_text = prompt_text
120
+ if handle_long_generation is None:
121
+ # forces truncation of inputs to avoid critical failure
122
+ handle_long_generation = None # disable with new approaches
123
+ return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
124
+ **generate_kwargs)
125
+
126
+ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
127
+ records = super().postprocess(model_outputs, return_type=return_type,
128
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces)
129
+ for rec in records:
130
+ if self.use_prompter:
131
+ outputs = rec['generated_text']
132
+ outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
133
+ sanitize_bot_response=self.sanitize_bot_response)
134
+ elif self.bot and self.human:
135
+ outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
136
+ else:
137
+ outputs = rec['generated_text']
138
+ rec['generated_text'] = outputs
139
+ return records
140
+
141
+ def _forward(self, model_inputs, **generate_kwargs):
142
+ if self.can_stop:
143
+ stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
144
+ self.tokenizer, self.device,
145
+ human=self.human, bot=self.bot,
146
+ model_max_length=self.tokenizer.model_max_length)
147
+ generate_kwargs['stopping_criteria'] = stopping_criteria
148
+ # return super()._forward(model_inputs, **generate_kwargs)
149
+ return self.__forward(model_inputs, **generate_kwargs)
150
+
151
+ # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
152
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/172
153
+ def __forward(self, model_inputs, **generate_kwargs):
154
+ input_ids = model_inputs["input_ids"]
155
+ attention_mask = model_inputs.get("attention_mask", None)
156
+ # Allow empty prompts
157
+ if input_ids.shape[1] == 0:
158
+ input_ids = None
159
+ attention_mask = None
160
+ in_b = 1
161
+ else:
162
+ in_b = input_ids.shape[0]
163
+ prompt_text = model_inputs.pop("prompt_text")
164
+
165
+ ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
166
+ ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
167
+ # generate_kwargs = copy.deepcopy(generate_kwargs)
168
+ prefix_length = generate_kwargs.pop("prefix_length", 0)
169
+ if prefix_length > 0:
170
+ has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
171
+ "generation_config" in generate_kwargs
172
+ and generate_kwargs["generation_config"].max_new_tokens is not None
173
+ )
174
+ if not has_max_new_tokens:
175
+ generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
176
+ generate_kwargs["max_length"] += prefix_length
177
+ has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
178
+ "generation_config" in generate_kwargs
179
+ and generate_kwargs["generation_config"].min_new_tokens is not None
180
+ )
181
+ if not has_min_new_tokens and "min_length" in generate_kwargs:
182
+ generate_kwargs["min_length"] += prefix_length
183
+
184
+ # BS x SL
185
+ generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
186
+ out_b = generated_sequence.shape[0]
187
+ if self.framework == "pt":
188
+ generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
189
+ elif self.framework == "tf":
190
+ from transformers import is_tf_available
191
+ if is_tf_available():
192
+ import tensorflow as tf
193
+ generated_sequence = tf.reshape(generated_sequence,
194
+ (in_b, out_b // in_b, *generated_sequence.shape[1:]))
195
+ else:
196
+ raise ValueError("TF not avaialble.")
197
+ return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
198
+ import torch
199
+ from transformers import StoppingCriteria, StoppingCriteriaList
200
+
201
+
202
+
203
+ class StoppingCriteriaSub(StoppingCriteria):
204
+
205
+ def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
206
+ super().__init__()
207
+ assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
208
+ self.encounters = encounters
209
+ self.stops = [stop.to(device) for stop in stops]
210
+ self.num_stops = [0] * len(stops)
211
+ self.model_max_length = model_max_length
212
+
213
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
214
+ for stopi, stop in enumerate(self.stops):
215
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
216
+ self.num_stops[stopi] += 1
217
+ if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
218
+ # print("Stopped", flush=True)
219
+ return True
220
+ if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
221
+ # critical limit
222
+ return True
223
+ # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
224
+ # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
225
+ return False
226
+
227
+
228
+ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
229
+ # FIXME: prompt_dict unused currently
230
+ if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
231
+ if prompt_type == PromptType.human_bot.name:
232
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
233
+ # stopping only starts once output is beyond prompt
234
+ # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
235
+ stop_words = [human, bot, '\n' + human, '\n' + bot]
236
+ encounters = [1, 2]
237
+ elif prompt_type == PromptType.instruct_vicuna.name:
238
+ # even below is not enough, generic strings and many ways to encode
239
+ stop_words = [
240
+ '### Human:',
241
+ """
242
+ ### Human:""",
243
+ """
244
+ ### Human:
245
+ """,
246
+ '### Assistant:',
247
+ """
248
+ ### Assistant:""",
249
+ """
250
+ ### Assistant:
251
+ """,
252
+ ]
253
+ encounters = [1, 2]
254
+ else:
255
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
256
+ stop_words = ['### End']
257
+ encounters = [1]
258
+ stop_words_ids = [
259
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
260
+ # handle single token case
261
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
262
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
263
+ # avoid padding in front of tokens
264
+ if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
265
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
266
+ # handle fake \n added
267
+ stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
268
+ # build stopper
269
+ stopping_criteria = StoppingCriteriaList(
270
+ [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
271
+ model_max_length=model_max_length)])
272
+ else:
273
+ stopping_criteria = StoppingCriteriaList()
274
+ return stopping_criteria
275
+ from enum import Enum
276
+
277
+
278
+ class PromptType(Enum):
279
+ custom = -1
280
+ plain = 0
281
+ instruct = 1
282
+ quality = 2
283
+ human_bot = 3
284
+ dai_faq = 4
285
+ summarize = 5
286
+ simple_instruct = 6
287
+ instruct_vicuna = 7
288
+ instruct_with_end = 8
289
+ human_bot_orig = 9
290
+ prompt_answer = 10
291
+ open_assistant = 11
292
+ wizard_lm = 12
293
+ wizard_mega = 13
294
+ instruct_vicuna2 = 14
295
+ instruct_vicuna3 = 15
296
+ wizard2 = 16
297
+ wizard3 = 17
298
+ instruct_simple = 18
299
+
300
+
301
+ class DocumentChoices(Enum):
302
+ All_Relevant = 0
303
+ All_Relevant_Only_Sources = 1
304
+ Only_All_Sources = 2
305
+ Just_LLM = 3
306
+
307
+
308
+ class LangChainMode(Enum):
309
+ """LangChain mode"""
310
+
311
+ DISABLED = "Disabled"
312
+ CHAT_LLM = "ChatLLM"
313
+ LLM = "LLM"
314
+ ALL = "All"
315
+ WIKI = "wiki"
316
+ WIKI_FULL = "wiki_full"
317
+ USER_DATA = "UserData"
318
+ MY_DATA = "MyData"
319
+ GITHUB_H2OGPT = "github h2oGPT"
320
+ H2O_DAI_DOCS = "DriverlessAI docs"
321
+ import ast
322
+ import time
323
+ from enums import PromptType # also supports imports from this file from other files
324
+
325
+ non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
326
+
327
+ prompt_type_to_model_name = {
328
+ 'plain': [
329
+ 'EleutherAI/gpt-j-6B',
330
+ 'EleutherAI/pythia-6.9b',
331
+ 'EleutherAI/pythia-12b',
332
+ 'EleutherAI/pythia-12b-deduped',
333
+ 'EleutherAI/gpt-neox-20b',
334
+ 'openlm-research/open_llama_7b_700bt_preview',
335
+ 'decapoda-research/llama-7b-hf',
336
+ 'decapoda-research/llama-13b-hf',
337
+ 'decapoda-research/llama-30b-hf',
338
+ 'decapoda-research/llama-65b-hf',
339
+ 'facebook/mbart-large-50-many-to-many-mmt',
340
+ 'philschmid/bart-large-cnn-samsum',
341
+ 'philschmid/flan-t5-base-samsum',
342
+ 'gpt2',
343
+ 'distilgpt2',
344
+ 'mosaicml/mpt-7b-storywriter',
345
+ 'mosaicml/mpt-7b-instruct', # internal code handles instruct
346
+ 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
347
+ 'gptj', # internally handles prompting
348
+ 'llama', # plain, or need to choose prompt_type for given TheBloke model
349
+ 'gpt4all_llama', # internally handles prompting
350
+ ],
351
+ 'prompt_answer': [
352
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
353
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
354
+ 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
355
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
356
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
357
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
358
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
359
+ 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
360
+ 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
361
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
362
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
363
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
364
+ ],
365
+ 'instruct': [],
366
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
367
+ 'quality': [],
368
+ 'human_bot': [
369
+ 'h2oai/h2ogpt-oasst1-512-12b',
370
+ 'h2oai/h2ogpt-oasst1-512-20b',
371
+ 'h2oai/h2ogpt-oig-oasst1-256-6_9b',
372
+ 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
373
+ 'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
374
+ 'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
375
+ 'h2oai/h2ogpt-research-oasst1-512-30b',
376
+ 'h2oai/h2ogpt-oasst1-falcon-40b',
377
+ 'h2oai/h2ogpt-oig-oasst1-falcon-40b',
378
+ ],
379
+ 'dai_faq': [],
380
+ 'summarize': [],
381
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
382
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
383
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
384
+ "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
385
+ "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
386
+ "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
387
+ "instruct_simple": ['JosephusCheung/Guanaco'],
388
+ }
389
+
390
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
391
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
392
+
393
+ prompt_types_strings = []
394
+ for p in PromptType:
395
+ prompt_types_strings.extend([p.name])
396
+
397
+ prompt_types = []
398
+ for p in PromptType:
399
+ prompt_types.extend([p.name, p.value, str(p.value)])
400
+
401
+
402
+ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
403
+ prompt_dict_error = ''
404
+ if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
405
+ try:
406
+ prompt_dict = ast.literal_eval(prompt_dict)
407
+ except BaseException as e:
408
+ prompt_dict_error = str(e)
409
+ if prompt_dict_error:
410
+ return dict(), prompt_dict_error
411
+
412
+ if prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
413
+ PromptType.custom.name]:
414
+ promptA = prompt_dict.get('promptA', '')
415
+ promptB = prompt_dict('promptB', '')
416
+ PreInstruct = prompt_dict.get('PreInstruct', '')
417
+ PreInput = prompt_dict.get('PreInput', '')
418
+ PreResponse = prompt_dict.get('PreResponse', '')
419
+ terminate_response = prompt_dict.get('terminate_response', None)
420
+ chat_sep = prompt_dict.get('chat_sep', '\n')
421
+ humanstr = prompt_dict.get('humanstr', '')
422
+ botstr = prompt_dict.get('botstr', '')
423
+ elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
424
+ PromptType.plain.name]:
425
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
426
+ terminate_response = []
427
+ chat_sep = ''
428
+ humanstr = ''
429
+ botstr = ''
430
+ elif prompt_type == 'simple_instruct':
431
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
432
+ terminate_response = []
433
+ chat_sep = '\n'
434
+ humanstr = ''
435
+ botstr = ''
436
+ elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
437
+ PromptType.instruct.name] + [PromptType.instruct_with_end.value,
438
+ str(PromptType.instruct_with_end.value),
439
+ PromptType.instruct_with_end.name]:
440
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
441
+ chat and reduced) else ''
442
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
443
+ chat and reduced) else ''
444
+
445
+ PreInstruct = """
446
+ ### Instruction:
447
+ """
448
+
449
+ PreInput = """
450
+ ### Input:
451
+ """
452
+
453
+ PreResponse = """
454
+ ### Response:
455
+ """
456
+ if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
457
+ PromptType.instruct_with_end.name]:
458
+ terminate_response = ['### End']
459
+ else:
460
+ terminate_response = None
461
+ chat_sep = '\n'
462
+ humanstr = PreInstruct
463
+ botstr = PreResponse
464
+ elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
465
+ PromptType.quality.name]:
466
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
467
+ chat and reduced) else ''
468
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
469
+ chat and reduced) else ''
470
+
471
+ PreInstruct = """
472
+ ### Instruction:
473
+ """
474
+
475
+ PreInput = """
476
+ ### Input:
477
+ """
478
+
479
+ PreResponse = """
480
+ ### Response:
481
+ """
482
+ terminate_response = None
483
+ chat_sep = '\n'
484
+ humanstr = PreInstruct # first thing human says
485
+ botstr = PreResponse # first thing bot says
486
+ elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
487
+ PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
488
+ str(PromptType.human_bot_orig.value),
489
+ PromptType.human_bot_orig.name]:
490
+ human = '<human>:'
491
+ bot = "<bot>:"
492
+ if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
493
+ PromptType.human_bot.name]:
494
+ preprompt = ''
495
+ else:
496
+ cur_date = time.strftime('%Y-%m-%d')
497
+ cur_time = time.strftime('%H:%M:%S %p %Z')
498
+
499
+ PRE_PROMPT = """\
500
+ Current Date: {}
501
+ Current Time: {}
502
+
503
+ """
504
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
505
+ start = human
506
+ promptB = promptA = '%s%s ' % (preprompt, start)
507
+
508
+ PreInstruct = ""
509
+
510
+ PreInput = None
511
+
512
+ if reduced:
513
+ # when making context, want it to appear as-if LLM generated, which starts with space after :
514
+ PreResponse = bot + ' '
515
+ else:
516
+ # normally LLM adds space after this, because was how trained.
517
+ # if add space here, non-unique tokenization will often make LLM produce wrong output
518
+ PreResponse = bot
519
+
520
+ terminate_response = [start, PreResponse]
521
+ chat_sep = '\n'
522
+ humanstr = human # tag before human talks
523
+ botstr = bot # tag before bot talks
524
+ elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
525
+ PromptType.dai_faq.name]:
526
+ promptA = ''
527
+ promptB = 'Answer the following Driverless AI question.\n'
528
+
529
+ PreInstruct = """
530
+ ### Driverless AI frequently asked question:
531
+ """
532
+
533
+ PreInput = None
534
+
535
+ PreResponse = """
536
+ ### Driverless AI documentation answer:
537
+ """
538
+ terminate_response = ['\n\n']
539
+ chat_sep = terminate_response
540
+ humanstr = PreInstruct
541
+ botstr = PreResponse
542
+ elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
543
+ PromptType.summarize.name]:
544
+ promptA = promptB = PreInput = ''
545
+ PreInstruct = '## Main Text\n\n'
546
+ PreResponse = '\n\n## Summary\n\n'
547
+ terminate_response = None
548
+ chat_sep = '\n'
549
+ humanstr = PreInstruct
550
+ botstr = PreResponse
551
+ elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
552
+ PromptType.instruct_vicuna.name]:
553
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
554
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
555
+ chat and reduced) else ''
556
+
557
+ PreInstruct = """
558
+ ### Human:
559
+ """
560
+
561
+ PreInput = None
562
+
563
+ PreResponse = """
564
+ ### Assistant:
565
+ """
566
+ terminate_response = [
567
+ '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
568
+ chat_sep = '\n'
569
+ humanstr = PreInstruct
570
+ botstr = PreResponse
571
+ elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
572
+ PromptType.prompt_answer.name]:
573
+ preprompt = ''
574
+ prompt_tokens = "<|prompt|>"
575
+ answer_tokens = "<|answer|>"
576
+ start = prompt_tokens
577
+ promptB = promptA = '%s%s' % (preprompt, start)
578
+ PreInstruct = ""
579
+ PreInput = None
580
+ PreResponse = answer_tokens
581
+ eos = '<|endoftext|>' # neox eos
582
+ terminate_response = [start, PreResponse, eos]
583
+ chat_sep = eos
584
+ humanstr = prompt_tokens
585
+ botstr = answer_tokens
586
+ elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
587
+ PromptType.open_assistant.name]:
588
+ # From added_tokens.json
589
+ preprompt = ''
590
+ prompt_tokens = "<|prompter|>"
591
+ answer_tokens = "<|assistant|>"
592
+ start = prompt_tokens
593
+ promptB = promptA = '%s%s' % (preprompt, start)
594
+ PreInstruct = ""
595
+ PreInput = None
596
+ PreResponse = answer_tokens
597
+ pend = "<|prefix_end|>"
598
+ eos = "</s>"
599
+ terminate_response = [start, PreResponse, pend, eos]
600
+ chat_sep = eos
601
+ humanstr = prompt_tokens
602
+ botstr = answer_tokens
603
+ elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
604
+ PromptType.wizard_lm.name]:
605
+ # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
606
+ preprompt = ''
607
+ start = ''
608
+ promptB = promptA = '%s%s' % (preprompt, start)
609
+ PreInstruct = ""
610
+ PreInput = None
611
+ PreResponse = "\n\n### Response\n"
612
+ eos = "</s>"
613
+ terminate_response = [PreResponse, eos]
614
+ chat_sep = eos
615
+ humanstr = promptA
616
+ botstr = PreResponse
617
+ elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
618
+ PromptType.wizard_mega.name]:
619
+ preprompt = ''
620
+ start = ''
621
+ promptB = promptA = '%s%s' % (preprompt, start)
622
+ PreInstruct = """
623
+ ### Instruction:
624
+ """
625
+ PreInput = None
626
+ PreResponse = """
627
+ ### Assistant:
628
+ """
629
+ terminate_response = [PreResponse]
630
+ chat_sep = '\n'
631
+ humanstr = PreInstruct
632
+ botstr = PreResponse
633
+ elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
634
+ PromptType.instruct_vicuna2.name]:
635
+ promptA = promptB = "" if not (
636
+ chat and reduced) else ''
637
+
638
+ PreInstruct = """
639
+ HUMAN:
640
+ """
641
+
642
+ PreInput = None
643
+
644
+ PreResponse = """
645
+ ASSISTANT:
646
+ """
647
+ terminate_response = [
648
+ 'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
649
+ chat_sep = '\n'
650
+ humanstr = PreInstruct
651
+ botstr = PreResponse
652
+ elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
653
+ PromptType.instruct_vicuna3.name]:
654
+ promptA = promptB = "" if not (
655
+ chat and reduced) else ''
656
+
657
+ PreInstruct = """
658
+ ### User:
659
+ """
660
+
661
+ PreInput = None
662
+
663
+ PreResponse = """
664
+ ### Assistant:
665
+ """
666
+ terminate_response = [
667
+ '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
668
+ chat_sep = '\n'
669
+ humanstr = PreInstruct
670
+ botstr = PreResponse
671
+ elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
672
+ PromptType.wizard2.name]:
673
+ # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
674
+ preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
675
+ start = ''
676
+ promptB = promptA = '%s%s' % (preprompt, start)
677
+ PreInstruct = """
678
+ ### Instruction:
679
+ """
680
+ PreInput = None
681
+ PreResponse = """
682
+ ### Response:
683
+ """
684
+ terminate_response = [PreResponse]
685
+ chat_sep = '\n'
686
+ humanstr = PreInstruct
687
+ botstr = PreResponse
688
+ elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
689
+ PromptType.wizard3.name]:
690
+ # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
691
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
692
+ start = ''
693
+ promptB = promptA = '%s%s' % (preprompt, start)
694
+ PreInstruct = """USER: """
695
+ PreInput = None
696
+ PreResponse = """ASSISTANT: """
697
+ terminate_response = [PreResponse]
698
+ chat_sep = '\n'
699
+ humanstr = PreInstruct
700
+ botstr = PreResponse
701
+
702
+ elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
703
+ PromptType.instruct_simple.name]:
704
+ promptA = '' if not (chat and reduced) else ''
705
+ promptB = '' if not (chat and reduced) else ''
706
+
707
+ PreInstruct = """
708
+ ### Instruction:
709
+ """
710
+
711
+ PreInput = """
712
+ ### Input:
713
+ """
714
+
715
+ PreResponse = """
716
+ ### Response:
717
+ """
718
+ terminate_response = None
719
+ chat_sep = '\n'
720
+ humanstr = PreInstruct
721
+ botstr = PreResponse
722
+ else:
723
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
724
+
725
+ if return_dict:
726
+ return dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
727
+ PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
728
+ humanstr=humanstr, botstr=botstr), ''
729
+ else:
730
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
731
+
732
+
733
+ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
734
+ context = data_point.get('context')
735
+ if context is None:
736
+ context = ''
737
+ instruction = data_point.get('instruction')
738
+ input = data_point.get('input')
739
+ output = data_point.get('output')
740
+ prompt_type = data_point.get('prompt_type', prompt_type)
741
+ prompt_dict = data_point.get('prompt_dict', prompt_dict)
742
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
743
+ promptA, promptB, PreInstruct, PreInput, PreResponse, \
744
+ terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, prompt_dict, chat, context, reduced)
745
+
746
+ prompt = context if not reduced else ''
747
+
748
+ if input and promptA:
749
+ prompt += f"""{promptA}"""
750
+ elif promptB:
751
+ prompt += f"""{promptB}"""
752
+
753
+ if instruction and PreInstruct is not None and input and PreInput is not None:
754
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
755
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
756
+ elif instruction and input and PreInstruct is None and PreInput is not None:
757
+ prompt += f"""{PreInput}{instruction}
758
+ {input}"""
759
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
760
+ elif input and instruction and PreInput is None and PreInstruct is not None:
761
+ prompt += f"""{PreInstruct}{instruction}
762
+ {input}"""
763
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
764
+ elif instruction and PreInstruct is not None:
765
+ prompt += f"""{PreInstruct}{instruction}"""
766
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
767
+ elif input and PreInput is not None:
768
+ prompt += f"""{PreInput}{input}"""
769
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
770
+ elif input and instruction and PreInput is not None:
771
+ prompt += f"""{PreInput}{instruction}{input}"""
772
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
773
+ elif input and instruction and PreInstruct is not None:
774
+ prompt += f"""{PreInstruct}{instruction}{input}"""
775
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
776
+ elif input and instruction:
777
+ # i.e. for simple_instruct
778
+ prompt += f"""{instruction}: {input}"""
779
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
780
+ elif input:
781
+ prompt += f"""{input}"""
782
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
783
+ elif instruction:
784
+ prompt += f"""{instruction}"""
785
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
786
+
787
+ if PreResponse is not None:
788
+ prompt += f"""{PreResponse}"""
789
+ pre_response = PreResponse # Don't use strip
790
+ else:
791
+ pre_response = ''
792
+
793
+ if output:
794
+ prompt += f"""{output}"""
795
+
796
+ return prompt, pre_response, terminate_response, chat_sep
797
+
798
+
799
+ def inject_chatsep(prompt_type, prompt, chat_sep=None):
800
+ if chat_sep:
801
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
802
+ prompt += chat_sep
803
+ return prompt
804
+
805
+
806
+ class Prompter(object):
807
+ def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
808
+ allowed_repeat_line_length=10):
809
+ self.prompt_type = prompt_type
810
+ self.prompt_dict = prompt_dict
811
+ data_point = dict(instruction='', input='', output='')
812
+ _, self.pre_response, self.terminate_response, self.chat_sep = \
813
+ generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
814
+ self.debug = debug
815
+ self.chat = chat
816
+ self.stream_output = stream_output
817
+ self.repeat_penalty = repeat_penalty
818
+ self.allowed_repeat_line_length = allowed_repeat_line_length
819
+ self.prompt = None
820
+ context = "" # not for chat context
821
+ reduced = False # not for chat context
822
+ self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
823
+ self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
824
+ get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced)
825
+
826
+ def generate_prompt(self, data_point):
827
+ reduced = False
828
+ prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced)
829
+ if self.debug:
830
+ print("prompt: %s" % prompt, flush=True)
831
+ self.prompt = prompt
832
+ return prompt
833
+
834
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
835
+ if isinstance(outputs, str):
836
+ outputs = [outputs]
837
+ if self.debug:
838
+ print("output:\n%s" % '\n\n'.join(outputs), flush=True)
839
+ if prompt is not None:
840
+ self.prompt = prompt
841
+
842
+ def clean_response(response):
843
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
844
+ for word in meaningless_words:
845
+ response = response.replace(word, "")
846
+ if sanitize_bot_response:
847
+ from better_profanity import profanity
848
+ response = profanity.censor(response)
849
+ response = response.strip("\n")
850
+ return response
851
+
852
+ def clean_repeats(response):
853
+ lines = response.split('\n')
854
+ new_lines = []
855
+ [new_lines.append(line) for line in lines if
856
+ line not in new_lines or len(line) < self.allowed_repeat_line_length]
857
+ if self.debug and len(lines) != len(new_lines):
858
+ print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
859
+ response = '\n'.join(new_lines)
860
+ return response
861
+
862
+ multi_output = len(outputs) > 1
863
+
864
+ for oi, output in enumerate(outputs):
865
+ if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
866
+ output = clean_response(output)
867
+ elif prompt is None:
868
+ # then use most basic parsing like pipeline
869
+ if self.botstr in output:
870
+ if self.humanstr:
871
+ output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
872
+ else:
873
+ # i.e. use after bot but only up to next bot
874
+ output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
875
+ else:
876
+ # output = clean_response(output.strip())
877
+ # assume just not printed yet
878
+ output = ""
879
+ else:
880
+ # find first instance of prereponse
881
+ # prompt sometimes has odd characters, that mutate length,
882
+ # so can't go by length alone
883
+ if self.pre_response:
884
+ outputi = output.find(prompt)
885
+ if outputi >= 0:
886
+ output = output[outputi + len(prompt):]
887
+ allow_terminate = True
888
+ else:
889
+ # subtraction is risky due to space offsets sometimes, so only do if necessary
890
+ output = output[len(prompt) - len(self.pre_response):]
891
+ # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
892
+ if self.pre_response in output:
893
+ output = output.split(self.pre_response)[1]
894
+ allow_terminate = True
895
+ else:
896
+ if output:
897
+ print("Failure of parsing or not enough output yet: %s" % output, flush=True)
898
+ allow_terminate = False
899
+ else:
900
+ allow_terminate = True
901
+ output = output[len(prompt):]
902
+ # clean after subtract prompt out, so correct removal of pre_response
903
+ output = clean_response(output).strip()
904
+ if self.repeat_penalty:
905
+ output = clean_repeats(output).strip()
906
+ if self.terminate_response and allow_terminate:
907
+ finds = []
908
+ for term in self.terminate_response:
909
+ finds.append(output.find(term))
910
+ finds = [x for x in finds if x >= 0]
911
+ if len(finds) > 0:
912
+ termi = finds[0]
913
+ output = output[:termi].strip()
914
+ else:
915
+ output = output.strip()
916
+ else:
917
+ output = output.strip()
918
+ if multi_output:
919
+ # prefix with output counter
920
+ output = "\n=========== Output %d\n\n" % (1 + oi) + output
921
+ if oi > 0:
922
+ # post fix outputs with seperator
923
+ output += '\n'
924
+ outputs[oi] = output
925
+ # join all outputs, only one extra new line between outputs
926
+ output = '\n'.join(outputs)
927
+ if self.debug:
928
+ print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
929
+ return output