BAAI
/

ldwang commited on
Commit
eb27b1a
·
1 Parent(s): 973ff4c

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +436 -0
predict.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copied from https://github.com/lm-sys/FastChat.
3
+ Later we will contribute our changes into it.
4
+ """
5
+ import dataclasses
6
+ from enum import auto, IntEnum
7
+ from typing import List, Any, Dict
8
+ import math
9
+ from typing import List, Optional, Tuple, Union
10
+ import random
11
+ import numpy as np
12
+
13
+ import torch
14
+ import torch.utils.checkpoint
15
+ from torch import nn
16
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
17
+
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
20
+ from transformers.modeling_utils import PreTrainedModel
21
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
22
+ from transformers import (
23
+ LogitsProcessorList,
24
+ MinLengthLogitsProcessor,
25
+ TopKLogitsWarper,
26
+ TemperatureLogitsWarper,
27
+ TopPLogitsWarper,
28
+ StoppingCriteriaList,
29
+ MaxLengthCriteria,
30
+ BitsAndBytesConfig,
31
+ )
32
+
33
+
34
+
35
+ class SeparatorStyle(IntEnum):
36
+ """Separator styles."""
37
+
38
+ ADD_COLON_SINGLE = auto()
39
+ ADD_COLON_TWO = auto()
40
+ ADD_COLON_SPACE_SINGLE = auto()
41
+ NO_COLON_SINGLE = auto()
42
+ NO_COLON_TWO = auto()
43
+ ADD_NEW_LINE_SINGLE = auto()
44
+
45
+
46
+ @dataclasses.dataclass
47
+ class Conversation:
48
+ """A class that manages prompt templates and keeps all conversation history."""
49
+
50
+ # The name of this template
51
+ name: str
52
+ # The template of the system prompt
53
+ system_template: str = "{system_message}"
54
+ # The system message
55
+ system_message: str = ""
56
+ # The names of two roles
57
+ roles: List[str] = (("USER", "ASSISTANT"),)
58
+ # All messages. Each item is (role, message).
59
+ messages: List[List[str]] = ()
60
+ # The number of few shot examples
61
+ offset: int = 0
62
+ # The separator style and configurations
63
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
64
+ sep: str = "\n"
65
+ sep2: str = None
66
+ # Stop criteria (the default one is EOS token)
67
+ stop_str: str = None
68
+ # Stops generation if meeting any token in this list
69
+ stop_token_ids: List[int] = None
70
+
71
+ def get_prompt(self) -> str:
72
+ """Get the prompt for generation."""
73
+ system_prompt = self.system_template.format(system_message=self.system_message)
74
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
75
+ ret = system_prompt + self.sep
76
+ for role, message in self.messages:
77
+ if message:
78
+ ret += role + ": " + message + self.sep
79
+ else:
80
+ ret += role + ":"
81
+ return ret
82
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
83
+ seps = [self.sep, self.sep2]
84
+ ret = system_prompt + seps[0]
85
+ for i, (role, message) in enumerate(self.messages):
86
+ if message:
87
+ ret += role + ": " + message + seps[i % 2]
88
+ else:
89
+ ret += role + ":"
90
+ return ret
91
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
92
+ ret = system_prompt + self.sep
93
+ for role, message in self.messages:
94
+ if message:
95
+ ret += role + ": " + message + self.sep
96
+ else:
97
+ ret += role + ": " # must be end with a space
98
+ return ret
99
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
100
+ ret = "" if system_prompt == "" else system_prompt + self.sep
101
+ for role, message in self.messages:
102
+ if message:
103
+ ret += role + "\n" + message + self.sep
104
+ else:
105
+ ret += role + "\n"
106
+ return ret
107
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
108
+ ret = system_prompt
109
+ for role, message in self.messages:
110
+ if message:
111
+ ret += role + message + self.sep
112
+ else:
113
+ ret += role
114
+ return ret
115
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
116
+ seps = [self.sep, self.sep2]
117
+ ret = system_prompt
118
+ for i, (role, message) in enumerate(self.messages):
119
+ if message:
120
+ ret += role + message + seps[i % 2]
121
+ else:
122
+ ret += role
123
+ return ret
124
+
125
+ def set_system_message(self, system_message: str):
126
+ """Set the system message."""
127
+ self.system_message = system_message
128
+
129
+ def append_message(self, role: str, message: str):
130
+ """Append a new message."""
131
+ self.messages.append([role, message])
132
+
133
+ def update_last_message(self, message: str):
134
+ """Update the last output.
135
+
136
+ The last message is typically set to be None when constructing the prompt,
137
+ so we need to update it in-place after getting the response from a model.
138
+ """
139
+ self.messages[-1][1] = message
140
+
141
+ def copy(self):
142
+ return Conversation(
143
+ name=self.name,
144
+ system_template=self.system_template,
145
+ system_message=self.system_message,
146
+ roles=self.roles,
147
+ messages=[[x, y] for x, y in self.messages],
148
+ offset=self.offset,
149
+ sep_style=self.sep_style,
150
+ sep=self.sep,
151
+ sep2=self.sep2,
152
+ stop_str=self.stop_str,
153
+ stop_token_ids=self.stop_token_ids,
154
+ )
155
+
156
+ def dict(self):
157
+ return {
158
+ "template_name": self.name,
159
+ "system_message": self.system_message,
160
+ "roles": self.roles,
161
+ "messages": self.messages,
162
+ "offset": self.offset,
163
+ }
164
+
165
+
166
+ # A global registry for all conversation templates
167
+ conv_templates: Dict[str, Conversation] = {}
168
+
169
+
170
+ def register_conv_template(template: Conversation, override: bool = False):
171
+ """Register a new conversation template."""
172
+ if not override:
173
+ assert (
174
+ template.name not in conv_templates
175
+ ), f"{template.name} has been registered."
176
+
177
+ conv_templates[template.name] = template
178
+
179
+
180
+ def get_conv_template(name: str) -> Conversation:
181
+ """Get a conversation template."""
182
+ return conv_templates[name].copy()
183
+
184
+ def get_conversation_template(model_path: str) -> Conversation:
185
+ """Get the default conversation template."""
186
+ if "aquila-v1" in model_path:
187
+ return get_conv_template("aquila-v1")
188
+ elif "aquila-chat" in model_path:
189
+ return get_conv_template("aquila-chat")
190
+ elif "aquila-legacy" in model_path:
191
+ return get_conv_template("aquila-legacy")
192
+ else:
193
+ return get_conv_template("aquila")
194
+
195
+ # AquilaChat default template
196
+ # source: https://github.com/FlagAI-Open/FlagAI/blob/master/examples/Aquila/Aquila-chat/cyg_conversation.py
197
+ register_conv_template(
198
+ Conversation(
199
+ name="aquila-chat",
200
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
201
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
202
+ roles=("Human", "Assistant", "System"),
203
+ messages=(),
204
+ offset=0,
205
+ sep_style=SeparatorStyle.ADD_COLON_SINGLE,
206
+ sep="###",
207
+ sep2="",
208
+ stop_str=["###", "</s>", "[UNK]"],
209
+ )
210
+ )
211
+
212
+ register_conv_template(
213
+ Conversation(
214
+ name="aquila-legacy",
215
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
216
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
217
+ roles=("### Human: ", "### Assistant: ", "System"),
218
+ messages=(),
219
+ offset=0,
220
+ sep_style=SeparatorStyle.NO_COLON_TWO,
221
+ sep="\n",
222
+ sep2="</s>",
223
+ stop_str=["</s>", "[UNK]"],
224
+ )
225
+ )
226
+
227
+ register_conv_template(
228
+ Conversation(
229
+ name="aquila",
230
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
231
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
232
+ roles=("Human", "Assistant", "System"),
233
+ messages=(),
234
+ offset=0,
235
+ sep_style=SeparatorStyle.ADD_COLON_TWO,
236
+ sep="###",
237
+ sep2="</s>",
238
+ stop_str=["</s>", "[UNK]"],
239
+ )
240
+ )
241
+
242
+ register_conv_template(
243
+ Conversation(
244
+ name="aquila-v1",
245
+ roles=("<|startofpiece|>", "<|endofpiece|>", ""),
246
+ messages=(),
247
+ offset=0,
248
+ sep_style=SeparatorStyle.NO_COLON_TWO,
249
+ sep="",
250
+ sep2="</s>",
251
+ stop_str=["</s>", "<|endoftext|>"],
252
+ )
253
+ )
254
+
255
+
256
+ if __name__ == "__main__":
257
+ print("aquila template:")
258
+ conv = get_conv_template("aquila")
259
+ conv.append_message(conv.roles[0], "Hello!")
260
+ conv.append_message(conv.roles[1], "Hi!")
261
+ conv.append_message(conv.roles[0], "How are you?")
262
+ conv.append_message(conv.roles[1], None)
263
+ print(conv.get_prompt())
264
+
265
+ print("\n")
266
+
267
+ print("aquila-chat template:")
268
+ conv = get_conv_template("aquila-chat")
269
+ conv.append_message(conv.roles[0], "Hello!")
270
+ conv.append_message(conv.roles[1], "Hi!")
271
+ conv.append_message(conv.roles[0], "How are you?")
272
+ conv.append_message(conv.roles[1], None)
273
+ print(conv.get_prompt())
274
+
275
+ print("\n")
276
+
277
+ print("aquila-v1 template:")
278
+ conv = get_conv_template("aquila-v1")
279
+ conv.append_message(conv.roles[0], "Hello!")
280
+ conv.append_message(conv.roles[1], "Hi!")
281
+ conv.append_message(conv.roles[0], "How are you?")
282
+ conv.append_message(conv.roles[1], None)
283
+ print(conv.get_prompt())
284
+
285
+ print("\n")
286
+
287
+ print("aquila-legacy template:")
288
+ conv = get_conv_template("aquila-legacy")
289
+ conv.append_message(conv.roles[0], "Hello!")
290
+ conv.append_message(conv.roles[1], "Hi!")
291
+ conv.append_message(conv.roles[0], "How are you?")
292
+ conv.append_message(conv.roles[1], None)
293
+ print(conv.get_prompt())
294
+
295
+ print("\n")
296
+
297
+ def set_random_seed(seed):
298
+ """Set random seed for reproducability."""
299
+ if seed is not None and seed > 0:
300
+ random.seed(seed)
301
+ np.random.seed(seed)
302
+ torch.manual_seed(seed)
303
+
304
+ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token, convo_template="aquila-chat"):
305
+ # aquila-chat as default
306
+ conv = get_conv_template(convo_template)
307
+
308
+ conv.append_message(conv.roles[1], None)
309
+ conv.append_message(conv.roles[0], text)
310
+
311
+ example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
312
+
313
+ while(len(history) > 0 and (len(example) < max_token)):
314
+ tmp = history.pop()
315
+ if tmp[0] == 'ASSISTANT':
316
+ conv.append_message(conv.roles[1], tmp[1])
317
+ else:
318
+ conv.append_message(conv.roles[0], tmp[1])
319
+ example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
320
+
321
+ if len(example) >= max_token:
322
+ conv.messages.pop()
323
+ conv.messages = conv.messages[::-1]
324
+ print('model in:', conv.get_prompt())
325
+ example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
326
+
327
+ return example
328
+
329
+ def predict(model, text, tokenizer=None,
330
+ max_gen_len=200, top_p=0.95,
331
+ seed=1234, topk=100,
332
+ temperature=0.9,
333
+ sft=True, convo_template = "",
334
+ device = "cuda",
335
+ model_name="AquilaChat2-7B",
336
+ **kwargs):
337
+
338
+ vocab = tokenizer.get_vocab()
339
+
340
+ id2word = {v:k for k, v in vocab.items()}
341
+
342
+
343
+ template_map = {"AquilaChat2-7B": "aquila-v1",
344
+ "AquilaChat2-34B": "aquila-legacy",
345
+ "AquilaChat2-7B-16K": "aquila",
346
+ "AquilaChat2-34B-16K": "aquila-v1"}
347
+ if not convo_template:
348
+ convo_template=template_map.get(model_name, "aquila-chat")
349
+
350
+ set_random_seed(seed)
351
+ if temperature == 0:
352
+ topk = 1
353
+ temperature = 1.0
354
+ if sft:
355
+ tokens = covert_prompt_to_input_ids_with_history(text, history=[], tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
356
+ tokens = torch.tensor(tokens)[None,].to(device)
357
+ else :
358
+ tokens = tokenizer.encode_plus(text)["input_ids"]
359
+ print(tokenizer.decode(tokens))
360
+ tokens = torch.tensor(tokens)[None,].to(device)
361
+ input_length = len(tokens[0])
362
+ with torch.no_grad():
363
+
364
+ # instantiate logits processors
365
+ logits_processor = LogitsProcessorList(
366
+ [
367
+ MinLengthLogitsProcessor(1, eos_token_id=100007),
368
+ ]
369
+ )
370
+ # instantiate logits processors
371
+ logits_warper = LogitsProcessorList(
372
+ [
373
+ TopPLogitsWarper(top_p),
374
+ TopKLogitsWarper(topk),
375
+ TemperatureLogitsWarper(temperature),
376
+
377
+ ]
378
+ )
379
+
380
+ stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=input_length + max_gen_len)])
381
+ out = model.sample(
382
+ tokens,
383
+ logits_processor=logits_processor,
384
+ logits_warper=logits_warper,
385
+ stopping_criteria=stopping_criteria,
386
+ return_dict_in_generate=True,
387
+ output_scores=True,
388
+ )
389
+
390
+
391
+ # print(out)
392
+ out_ids = out["sequences"][0][input_length:].cpu().numpy()
393
+
394
+ out_scores = out["scores"]
395
+
396
+ out_scores = torch.cat(out_scores, dim=0)
397
+ out_scores = torch.nn.functional.softmax(out_scores, dim=-1).cpu().numpy()
398
+
399
+ probs = []
400
+ for i in range(len(out_ids)):
401
+ probs.append(float(out_scores[i][out_ids[i]]))
402
+
403
+ # print(f"probs is {probs}")
404
+
405
+ convert_tokens = []
406
+ for t in out_ids:
407
+ if t == 100006:
408
+ convert_tokens.append("[CLS]")
409
+ else :
410
+ convert_tokens.append(id2word.get(t, "[unkonwn_token]"))
411
+
412
+ out_text = tokenizer.decode(out_ids.tolist())
413
+
414
+
415
+ out = out_text
416
+
417
+ if "[UNK]" in out:
418
+ special_index = out.index("[UNK]")
419
+ out = out[:special_index]
420
+ token_length = len(tokenizer.encode_plus(out)["input_ids"])
421
+ convert_tokens = convert_tokens[:token_length]
422
+ probs = probs[:token_length]
423
+
424
+ if "</s>" in out:
425
+ special_index = out.index("</s>")
426
+ out = out[: special_index]
427
+ token_length = len(tokenizer.encode_plus(out)["input_ids"])
428
+ convert_tokens = convert_tokens[:token_length]
429
+ probs = probs[:token_length]
430
+
431
+ if len(out) > 0 and out[0] == " ":
432
+ out = out[1:]
433
+
434
+ convert_tokens = convert_tokens[1:]
435
+ probs = probs[1:]
436
+ return out