JosephusCheung committed on
Commit 50e6685
1 Parent(s): 5e3a6ee

Upload 3 files

eval/evaluate_chatml_ceval.py ADDED
@@ -0,0 +1,632 @@
import os
import pandas as pd
import numpy as np
import argparse
import datasets
import torch
import re
from thefuzz import process
from typing import List
from tqdm import tqdm
from transformers.trainer_utils import set_seed

from typing import Tuple, List, Union, Iterable

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List

HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]

def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    if history is None:
        history = []

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        return f"{role}\n{content}", tokenizer.encode(
            role
        ) + nl_tokens + tokenizer.encode(content)

    system_text, system_tokens_part = _tokenize_str("system", system)
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

    raw_text = ""
    context_tokens = []

    for turn_query, turn_response in reversed(history):
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
        response_text, response_tokens_part = _tokenize_str(
            "assistant", turn_response
        )
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        current_context_size = (
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
        )
        if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
        else:
            break

    context_tokens = system_tokens + context_tokens
    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
    context_tokens += (
        nl_tokens
        + im_start_tokens
        + _tokenize_str("user", query)[1]
        + im_end_tokens
        + nl_tokens
        + im_start_tokens
        + tokenizer.encode("assistant")
        + nl_tokens
    )
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    return raw_text, context_tokens

def chat(
    model,
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[HistoryType],
    system: str = "You are a helpful assistant.",
    append_history: bool = True
) -> Tuple[str, HistoryType]:

    if history is None:
        history = []

    raw_text, context_tokens = make_context(
        tokenizer,
        query,
        history=history,
        system=system,
        max_window_size=6144,
        chat_format="chatml",
    )

    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    input_ids = torch.tensor([context_tokens]).cuda()
    outputs = model.generate(
        input_ids,
        # stop_words_ids=stop_words_ids,
        return_dict_in_generate=False,
    )

    response = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        chat_format='chatml',
        verbose=False,
    )

    if append_history:
        history.append((query, response))

    return response, history

def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str = "chatml",
    verbose: bool = False,
    return_end_reason: bool = False,
) -> str:
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    return _decode_chatml(
        tokens,
        stop_words=[],
        eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
        tokenizer=tokenizer,
        raw_text_len=raw_text_len,
        context_length=context_length,
        verbose=verbose,
        return_end_reason=return_end_reason,
    )


def _decode_chatml(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_token_ids: List[int],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    chat_format="chatml",
):
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = context_length
    for eod_token_idx in range(context_length, len(tokens)):
        if tokens[eod_token_idx] in eod_token_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
            break

    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]
    if verbose:
        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens)[raw_text_len:])
        print("\nRaw Generate:", trim_decode_tokens)
        print("\nEnd Reason:", end_reason)
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nGenerate:", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    else:
        return trim_decode_tokens


def load_models_tokenizer(args):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model.generation_config.do_sample = False  # use greedy decoding
    return model, tokenizer


def process_before_extraction(gen, question, choice_dict):
    # Example Prompt:
    # 关于传输层的面向连接服务的特性是____。
    # A. 既不保证可靠,也不保证按序交付
    # B. 不保证可靠,但保证按序交付
    # C. 保证可靠,但不保证按序交付
    # D. 既保证可靠,也保证按序交付
    # Example Model Output:
    # 关于传输层的面向连接服务的特性是既保证可靠,也保证按序交付
    # Processed Output:
    # 答案是D

    question_split = question.rstrip("。").split("。")[-1].split("_")

    # replacing the question
    if len(question_split[0].strip()) > 4:
        gen = gen.replace(question_split[0], "答案是")
    if len(question_split[-1].strip()) > 4:
        gen = gen.replace(question_split[-1], "")

    # replace the choice by letter in the generated sentence
    # from longest one to shortest one
    for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
        gen = gen.replace(val.rstrip("。"), key)
    return gen


def count_substr(gen, pattern):
    return len(re.findall(pattern, gen))


def extract_choice(gen, prompt, choice_list):
    # 答案是A | 选项是A | 应该选A选项
    res = re.search(
        r"(?:(?:选|选择|选定)[::]?\s*|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$|:|:|\)|))",
        gen,
    )

    # A选项正确 | A选项符合题意
    if res is None:
        res = re.search(
            r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对[的,。:]|符合))[^ABCD]{0,4}?(?:正确|对[的,。:]|符合)",
            gen,
        )

    # 直接输出 A
    if res is None:
        res = re.search(r"^[\((]?(A|B|C|D)(?:。|\)|)|\.|,|,|.|:|:|$)", gen)

    # 获取第一个出现的字母
    if res is None:
        res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)

    if res is None:
        return choices[choice_list.index(process.extractOne(gen, choice_list)[0])]
    return res.group(1)


def format_example(line):
    example = line["question"] + "\n\n"
    for choice in choices:
        example += f'{choice}. {line[f"{choice}"]}\n'
    return example


def extract_answer(response, row):
    prompt = row["question"]
    gen = process_before_extraction(
        response, prompt, {choice: row[choice] for choice in choices}
    )
    if not isinstance(prompt, str):
        prompt = prompt[0]
    pred = extract_choice(gen, prompt, [row[choice] for choice in choices])
    return pred


@torch.no_grad()
def eval_subject(
    model,
    tokenizer,
    subject_name,
    test_df,
    save_result_dir=None,
    overwrite=False,
    **kwargs
):
    result_path = os.path.join(save_result_dir, f"{subject_name}_result.csv")
    if not overwrite and os.path.exists(result_path):
        print(f"{result_path} existed, skip!")
        score = []
        for (_, datarow), (_, resultrow) in zip(
            test_df.iterrows(), pd.read_csv(result_path).iterrows()
        ):
            pred = extract_answer(resultrow["model_response"], datarow)
            correct = 1 if pred == datarow["answer"] else 0
            score.append(correct)
        correct_ratio = 100 * sum(score) / len(score)
        return correct_ratio

    responses = []
    result = []
    score = []

    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        question = format_example(row)

        response, _ = chat(
            model,
            tokenizer,
            question,
            history=None,
        )
        print(question)
        print(response)
        pred = extract_answer(response, row)
        print(pred)
        print("======================")

        if "answer" in row:
            correct = 1 if pred == row["answer"] else 0
            score.append(correct)
            if args.debug:
                print(f'{question} pred: {pred} ref: {row["answer"]}')
        responses.append(response)
        result.append(pred)

    if score:
        correct_ratio = 100 * sum(score) / len(score)
        if args.debug:
            print(subject_name, correct_ratio)
    else:
        correct_ratio = 0
    if save_result_dir:
        test_df["model_response"] = responses
        test_df["model_output"] = result
        if score:
            test_df["correctness"] = score
        os.makedirs(save_result_dir, exist_ok=True)
        test_df.to_csv(result_path, encoding="utf-8", index=False)

    return correct_ratio


def cal_ceval(res):
    acc_sum_dict = dict()
    acc_norm_sum_dict = dict()
    cnt_dict = dict()
    acc_sum = 0.0
    cnt = 0
    hard_cnt = 0
    hard_acc_sum = 0.0
    for tt in res.keys():
        name = tt.split("-")[-1]
        acc_sum += float(res[tt])
        cnt += 1
        class_ = TASK_NAME_MAPPING[name][2]
        if class_ not in acc_sum_dict:
            acc_sum_dict[class_] = 0.0
            acc_norm_sum_dict[class_] = 0.0
            cnt_dict[class_] = 0.0
        if name in hard_list:
            hard_cnt += 1
            hard_acc_sum += float(res[tt])
        acc_sum_dict[class_] += float(res[tt])
        cnt_dict[class_] += 1
    print("\n\n\n")
    for k in ["STEM", "Social Science", "Humanities", "Other"]:
        if k in cnt_dict:
            print("%s acc: %.2f " % (k, acc_sum_dict[k] / cnt_dict[k]))
    if hard_cnt > 0:
        print("Hard acc:%.2f " % (hard_acc_sum / hard_cnt))
    print("AVERAGE acc:%.2f " % (acc_sum / cnt))


TASK_NAME_MAPPING = {
    "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
    "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
    "computer_architecture": ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
    "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
    "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
    "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
    "advanced_mathematics": ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
    "probability_and_statistics": ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
    "discrete_mathematics": ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
    "electrical_engineer": ["Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", "STEM"],
    "metrology_engineer": ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
    "high_school_mathematics": ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
    "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
    "high_school_chemistry": ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
    "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
    "middle_school_mathematics": ["Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"],
    "middle_school_biology": ["Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"],
    "middle_school_physics": ["Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"],
    "middle_school_chemistry": ["Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"],
    "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
    "college_economics": ["College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"],
    "business_administration": ["Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"],
    "marxism": ["Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", "Social Science"],
    "mao_zedong_thought": ["Mao Zedong Thought", "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", "Social Science"],
    "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
    "teacher_qualification": ["Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"],
    "high_school_politics": ["High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"],
    "high_school_geography": ["High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"],
    "middle_school_politics": ["Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"],
    "middle_school_geography": ["Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"],
    "modern_chinese_history": ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
    "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", "Humanities"],
    "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
    "law": ["Law", "\u6cd5\u5b66", "Humanities"],
    "chinese_language_and_literature": ["Chinese Language and Literature", "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"],
    "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
    "professional_tour_guide": ["Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"],
    "legal_professional": ["Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", "Humanities"],
    "high_school_chinese": ["High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"],
    "high_school_history": ["High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"],
    "middle_school_history": ["Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"],
    "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
    "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
    "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
    "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
    "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
    "urban_and_rural_planner": ["Urban and Rural Planner", "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"],
    "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
    "fire_engineer": ["Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"],
    "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"],
    "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
    "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"],
}
hard_list = [
    "advanced_mathematics", "discrete_mathematics", "probability_and_statistics",
    "college_physics", "college_chemistry", "high_school_mathematics",
    "high_school_physics", "high_school_chemistry",
]
choices = ["A", "B", "C", "D"]


def main(args):
    print("loading model weights")
    if args.checkpoint_path:
        model, tokenizer = load_models_tokenizer(args)
    else:
        model, tokenizer = None, None
    print("model loaded")
    dev_result = {}
    for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
        val_file_path = os.path.join(
            args.eval_data_path, "val", f"{subject_name}_val.csv"
        )
        val_df = pd.read_csv(val_file_path)

        score = eval_subject(
            model,
            tokenizer,
            subject_name,
            val_df,
            save_result_dir="outs_chat/ceval_eval_result",
            overwrite=args.overwrite,
        )
        dev_result[subject_name] = score
    cal_ceval(dev_result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        help="Checkpoint path",
        default="Qwen/Qwen-7B-Chat",
    )
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")

    # Provide extra arguments required for tasks
    group = parser.add_argument_group(title="Evaluation options")
    group.add_argument(
        "-d", "--eval_data_path", type=str, required=True, help="Path to eval data"
    )
    group.add_argument(
        "--debug", action="store_true", default=False, help="Print infos."
    )
    group.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Overwrite existed results",
    )

    args = parser.parse_args()
    set_seed(args.seed)

    main(args)
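For reference, a typical invocation of the C-Eval script above might look like the following; the data directory is a placeholder and is assumed to contain the C-Eval val/ CSVs (one <subject>_val.csv per subject), while -c defaults to Qwen/Qwen-7B-Chat:

python eval/evaluate_chatml_ceval.py -d data/ceval -c Qwen/Qwen-7B-Chat

Per-subject responses are written under outs_chat/ceval_eval_result/ and category accuracies are printed by cal_ceval at the end of the run.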
eval/evaluate_chatml_gsm8k.py ADDED
@@ -0,0 +1,358 @@
import json
import re
from pathlib import Path
import argparse
import numpy as np
import tqdm
from datasets import load_from_disk, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

import os
import pandas as pd
import numpy as np
import argparse
import datasets
import torch
import re
from thefuzz import process
from typing import List
from tqdm import tqdm
from transformers.trainer_utils import set_seed

from typing import Tuple, List, Union, Iterable

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List

HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]

def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    if history is None:
        history = []

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        return f"{role}\n{content}", tokenizer.encode(
            role
        ) + nl_tokens + tokenizer.encode(content)

    system_text, system_tokens_part = _tokenize_str("system", system)
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

    raw_text = ""
    context_tokens = []

    for turn_query, turn_response in reversed(history):
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
        response_text, response_tokens_part = _tokenize_str(
            "assistant", turn_response
        )
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        current_context_size = (
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
        )
        if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
        else:
            break

    context_tokens = system_tokens + context_tokens
    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
    context_tokens += (
        nl_tokens
        + im_start_tokens
        + _tokenize_str("user", query)[1]
        + im_end_tokens
        + nl_tokens
        + im_start_tokens
        + tokenizer.encode("assistant")
        + nl_tokens
    )
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    return raw_text, context_tokens

def chat(
    model,
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[HistoryType],
    system: str = "You are a helpful assistant.",
    append_history: bool = True
) -> Tuple[str, HistoryType]:

    if history is None:
        history = []

    raw_text, context_tokens = make_context(
        tokenizer,
        query,
        history=history,
        system=system,
        max_window_size=6144,
        chat_format="chatml",
    )

    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    input_ids = torch.tensor([context_tokens]).cuda()
    outputs = model.generate(
        input_ids,
        # stop_words_ids=stop_words_ids,
        return_dict_in_generate=False,
    )

    response = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        chat_format='chatml',
        verbose=False,
    )

    if append_history:
        history.append((query, response))

    return response, history

def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str = "chatml",
    verbose: bool = False,
    return_end_reason: bool = False,
) -> str:
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    return _decode_chatml(
        tokens,
        stop_words=[],
        eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
        tokenizer=tokenizer,
        raw_text_len=raw_text_len,
        context_length=context_length,
        verbose=verbose,
        return_end_reason=return_end_reason,
    )


def _decode_chatml(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_token_ids: List[int],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    chat_format="chatml",
):
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = context_length
    for eod_token_idx in range(context_length, len(tokens)):
        if tokens[eod_token_idx] in eod_token_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
            break

    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]
    if verbose:
        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens)[raw_text_len:])
        print("\nRaw Generate:", trim_decode_tokens)
        print("\nEnd Reason:", end_reason)
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nGenerate:", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    else:
        return trim_decode_tokens


def load_models_tokenizer(args):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model.generation_config.do_sample = False  # use greedy decoding
    return model, tokenizer

'''
python eval/evaluate_chat_gsm8k.py [--use-fewshot]
'''

INVALID_ANS = "[invalid]"
DEVICE = "cuda:0"

def doc_to_text(doc, use_fewshot):
    if use_fewshot:
        context = (
            "Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\n"
            "Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n\n"
            "Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?\nLet's think step by step\n"
            "Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\nHis team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\nThey scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\nAll together his team scored 50+24+10= 84 points\nMark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\nHis opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\nThey also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\nAll together Mark's opponents scored 100+12+5=117 points\nThe total score for the game is both team's scores added together, so it is 84+117=201 points\nThe answer is 201\n\n"
            "Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\nLet's think step by step\n"
            "When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\nThe total number of marbles she'll have is 60+24 = 84\nIf Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\nIf Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\nThe total number of frisbees she'll have will increase to 30+12 = 42\nBella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\nIf she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\nThe total number of deck cards she'll have is 10+4 = 14\nTogether, Bella will have a total of 14+42+84 = 140 items\nThe answer is 140\n\n"
            "Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\nLet's think step by step\n"
            "For the first three baskets, the number of apples and oranges in one basket is 9+15=24\nIn total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\nSince there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\nThe number of apples in the fourth basket is 9-2=7\nThere are also 15-2=13 oranges in the fourth basket\nThe combined number of oranges and apples in the fourth basket is 13+7=20\nThe fourth basket also contains 14-2=12 bananas.\nIn total, the fourth basket has 20+12=32 fruits.\nThe four baskets together have 32+114=146 fruits.\nThe answer is 146\n\n"
            f"Question: {doc['question']}\nLet's think step by step"
        )
    else:
        context = doc["question"]
    return context


def decode(tokens_list, tokenizer, raw_text_len):
    sents = []
    for tokens in tokens_list:
        tokens = tokens.cpu().numpy().tolist()
        sent = tokenizer.tokenizer.decode(tokens[raw_text_len:])
        sent = sent.split("<|endoftext|>")[0]
        sent = sent.split("\n\n\n")[0]
        sent = sent.split("\n\n")[0]
        sent = sent.split("Question:")[0]
        sents.append(sent)
    return sents


def generate_sample(model, tokenizer, question):
    response, _ = chat(
        model,
        tokenizer,
        question,
        history=None,
    )
    print(question)
    print("-------------")
    print(response)
    print("=============")
    return response


def extract_answer_hf(completion):
    def _get_last_digit(s):
        _PAT_LAST_DIGIT = re.compile(
            r"(?<=(\s|[\$%#{]))([+-])?(?=(\S))(0|([1-9](\d*|\d{0,2}(,\d{3})*)))?(\.\d*[1-9])?(?=(\s|[.,}]|$))"
        )
        match = list(_PAT_LAST_DIGIT.finditer(s))
        if match:
            last_digit = match[-1].group().replace(",", "").replace("+", "")
            # print(f"The last digit in {s} is {last_digit}")
        else:
            last_digit = None
            print(f"No digits found in {s!r}")
        return last_digit

    job_gen = completion.strip(".").replace("\n", "\\n")
    last_digit = _get_last_digit(job_gen)
    if last_digit is not None:
        return eval(last_digit)
    return INVALID_ANS


def extract_answer(completion):
    try:
        last_number = re.findall(r"\d+", completion)[-1]
        return eval(last_number)
    except:
        return INVALID_ANS


def is_correct(completion, answer):
    gold = extract_answer(answer)
    assert gold != INVALID_ANS, "No ground truth answer found in the document."
    return extract_answer(completion) == gold


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test HF checkpoint.")
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=Path,
        help="Checkpoint path",
        default="Qwen/Qwen-7B-Chat",
    )
    parser.add_argument("-f", "--sample-input-file", type=str, default=None)
    parser.add_argument(
        "-o", "--sample-output-file", type=str, default="gsm8k_res.jsonl"
    )
    parser.add_argument("--use-fewshot", action="store_true")

    args = parser.parse_args()

    if args.sample_input_file is not None:
        dataset = load_from_disk(args.sample_input_file)  # or:
    else:
        dataset = load_dataset("gsm8k", "main")

    print("Loading tokenizer ...")
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, bf16=True, use_flash_attn=True
    )

    print("Loading model ...")
    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path, device_map="auto", trust_remote_code=True
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )
    model.generation_config.do_sample = False  # use greedy decoding

    test = dataset["test"]

    f_output = open(args.sample_output_file, "w", encoding="utf-8")
    tot_length = test.num_rows
    acc_res = []
    for doc in tqdm(test):
        context = doc_to_text(doc, args.use_fewshot)
        print(context)
        completion = generate_sample(model, tokenizer, context)
        answer = doc["answer"]
        acc = is_correct(completion, answer)
        doc["completion"] = completion
        doc["acc"] = acc
        f_output.write(json.dumps(doc, ensure_ascii=False) + "\n")
        f_output.flush()
        acc_res.append(acc)

    f_output.close()
    print("4-shot Acc: " if args.use_fewshot else "Zero-shot Acc", np.mean(acc_res))
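A typical run of the GSM8K script above: by default the dataset is pulled with load_dataset("gsm8k", "main"); -f can instead point at a local copy saved with datasets' save_to_disk, and per-example results are appended to gsm8k_res.jsonl (overridable with -o):

python eval/evaluate_chatml_gsm8k.py -c Qwen/Qwen-7B-Chat --use-fewshot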
eval/evaluate_chatml_mmlu.py ADDED
@@ -0,0 +1,381 @@
import os
import pandas as pd
import numpy as np
import argparse
import datasets
import torch
import re
from thefuzz import process
from typing import List
from tqdm import tqdm
from transformers.trainer_utils import set_seed

from typing import Tuple, List, Union, Iterable

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List

HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]

def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
    chat_format: str = "chatml",
):
    if history is None:
        history = []

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        return f"{role}\n{content}", tokenizer.encode(
            role
        ) + nl_tokens + tokenizer.encode(content)

    system_text, system_tokens_part = _tokenize_str("system", system)
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

    raw_text = ""
    context_tokens = []

    for turn_query, turn_response in reversed(history):
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
        response_text, response_tokens_part = _tokenize_str(
            "assistant", turn_response
        )
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        current_context_size = (
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
        )
        if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
        else:
            break

    context_tokens = system_tokens + context_tokens
    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
    context_tokens += (
        nl_tokens
        + im_start_tokens
        + _tokenize_str("user", query)[1]
        + im_end_tokens
        + nl_tokens
        + im_start_tokens
        + tokenizer.encode("assistant")
        + nl_tokens
    )
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

    return raw_text, context_tokens

def chat(
    model,
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: Optional[HistoryType],
    system: str = "You are a helpful assistant.",
    append_history: bool = True
) -> Tuple[str, HistoryType]:

    if history is None:
        history = []

    raw_text, context_tokens = make_context(
        tokenizer,
        query,
        history=history,
        system=system,
        max_window_size=6144,
        chat_format="chatml",
    )

    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    input_ids = torch.tensor([context_tokens]).cuda()
    outputs = model.generate(
        input_ids,
        # stop_words_ids=stop_words_ids,
        return_dict_in_generate=False,
    )

    response = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        chat_format='chatml',
        verbose=False,
    )

    if append_history:
        history.append((query, response))

    return response, history

def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str = "chatml",
    verbose: bool = False,
    return_end_reason: bool = False,
) -> str:
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    return _decode_chatml(
        tokens,
        stop_words=[],
        eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
        tokenizer=tokenizer,
        raw_text_len=raw_text_len,
        context_length=context_length,
        verbose=verbose,
        return_end_reason=return_end_reason,
    )


def _decode_chatml(
    tokens: List[int],
    *,
    stop_words: List[str],
    eod_token_ids: List[int],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    verbose: bool = False,
    return_end_reason: bool = False,
    chat_format="chatml",
):
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = context_length
    for eod_token_idx in range(context_length, len(tokens)):
        if tokens[eod_token_idx] in eod_token_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
            break

    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]
    if verbose:
        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens)[raw_text_len:])
        print("\nRaw Generate:", trim_decode_tokens)
        print("\nEnd Reason:", end_reason)
    for stop_word in stop_words:
        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
    trim_decode_tokens = trim_decode_tokens.strip()
    if verbose:
        print("\nGenerate:", trim_decode_tokens)

    if return_end_reason:
        return trim_decode_tokens, end_reason
    else:
        return trim_decode_tokens


def load_models_tokenizer(args):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation import GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto", trust_remote_code=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(args.checkpoint_path, trust_remote_code=True)
    model.generation_config.do_sample = False  # use greedy decoding
    return model, tokenizer


def format_example(line):
    example = 'The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n' + line['question'] + "\n"
    for choice in choices:
        example += f'{choice}. {line[f"{choice}"]}\n'
    return example


def process_before_extraction(gen, choice_dict):
    # replace the choice by letter in the generated sentence
    # from longest one to shortest one
    for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
        pattern = re.compile(re.escape(val.rstrip(".")), re.IGNORECASE)
        gen = pattern.sub(key, gen)
    return gen

def extract_choice(gen, choice_list):
    # answer is A | choice is A | choose A
    res = re.search(r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b", gen)

    # A is correct | A is right
    if res is None:
        res = re.search(r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b", gen)

    # straight answer: A
    if res is None:
        res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)

    # simply extract the first appearred letter
    if res is None:
        res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)

    if res is None:
        return choices[choice_list.index(process.extractOne(gen, choice_list)[0])]
    else:
        return res.group(1)

def extract_answer(response, row):
    gen = process_before_extraction(response, {choice: row[choice] for choice in choices})
    pred = extract_choice(gen, [row[choice] for choice in choices])
    return pred

@torch.no_grad()
def eval_subject(
    model,
    tokenizer,
    subject_name,
    test_df,
    save_result_dir=None,
    overwrite=False,
    **kwargs
):
    result_path = os.path.join(save_result_dir, f'{subject_name}_result.csv')
    if not overwrite and os.path.exists(result_path):
        print(f"{result_path} existed, skip!")
        score = []
        for (_, datarow), (_, resultrow) in zip(test_df.iterrows(), pd.read_csv(result_path).iterrows()):
            # pred = extract_answer(resultrow['model_response'], datarow)
            pred = resultrow['model_output']
            correct = 1 if pred == datarow['answer'] else 0
            score.append(correct)
        return score

    result = []
    score = []

    for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
        question = format_example(row)

        response, history = chat(
            model,
            tokenizer,
            question,
            history=None,
        )
        print(question)
        print(response)
        pred = extract_answer(response, row)
        print(pred)
        print("======================")

        if 'answer' in row:
            correct = 1 if pred == row['answer'] else 0
            score.append(correct)
            if args.debug:
                print(f'{question} pred: {pred} ref: {row["answer"]}')
        result.append(pred)

    if save_result_dir:
        test_df['model_output'] = result
        test_df['model_response'] = response
        if score:
            test_df["correctness"] = score
        os.makedirs(save_result_dir, exist_ok=True)
        test_df.to_csv(os.path.join(
            save_result_dir, f'{subject_name}_result.csv'), encoding="utf-8", index=False)

    return score


def cal_mmlu(res):
    acc_sum_dict = dict()
    acc_norm_sum_dict = dict()
    cnt_dict = dict()
    acc_sum = 0.
    cnt = 0
    hard_cnt = 0
    hard_acc_sum = 0.

    for class_ in TASK_NAME_MAPPING.keys():
        acc_sum_dict[class_] = 0.
        acc_norm_sum_dict[class_] = 0.
        cnt_dict[class_] = 0.

        for tt in TASK_NAME_MAPPING[class_]:
            acc_sum += sum(res[tt])
            cnt += len(res[tt])

            acc_sum_dict[class_] += sum(res[tt])
            cnt_dict[class_] += len(res[tt])

    print('\n\n\n')
    for k in TASK_NAME_MAPPING.keys():
        if k in cnt_dict:
            print('%s ACC: %.2f ' % (
                k, acc_sum_dict[k] * 100 / cnt_dict[k]))
    print('AVERAGE ACC:%.2f ' % (acc_sum * 100 / cnt))


def main(args):
    print("loading model weights")
    if args.checkpoint_path is not None:
        model, tokenizer = load_models_tokenizer(args)
    else:
        model, tokenizer = None, None
    print("model loaded")

    dev_result = {}
    for subject_name in tqdm(SUBJECTS):
        # val_file_path = os.path.join(args.eval_data_path, 'val', f'{subject_name}_val.csv')
        # dev_file_path = os.path.join(args.eval_data_path, 'dev', f'{subject_name}_dev.csv')
        test_file_path = os.path.join(args.eval_data_path, 'test', f'{subject_name}_test.csv')
        # val_df = pd.read_csv(val_file_path, names=['question','A','B','C','D','answer'])
        # dev_df = pd.read_csv(dev_file_path, names=['question','A','B','C','D','answer'])
        test_df = pd.read_csv(test_file_path, names=['question', 'A', 'B', 'C', 'D', 'answer'])

        score = eval_subject(model, tokenizer, subject_name, test_df, save_result_dir="outs_chat/mmlu_eval_result", overwrite=args.overwrite)
        dev_result[subject_name] = score
    cal_mmlu(dev_result)


TASK_NAME_MAPPING = {
    'stem': ['abstract_algebra', 'anatomy', 'astronomy', 'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 'conceptual_physics', 'electrical_engineering', 'elementary_mathematics', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_mathematics', 'high_school_physics', 'high_school_statistics', 'machine_learning'],
    'Humanities': ['formal_logic', 'high_school_european_history', 'high_school_us_history', 'high_school_world_history', 'international_law', 'jurisprudence', 'logical_fallacies', 'moral_disputes', 'moral_scenarios', 'philosophy', 'prehistory', 'professional_law', 'world_religions'],
    'other': ['business_ethics', 'college_medicine', 'human_aging', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'nutrition', 'professional_accounting', 'professional_medicine', 'virology', 'global_facts', 'clinical_knowledge'],
    'social': ['econometrics', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_psychology', 'human_sexuality', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy'],
}
SUBJECTS = [v for vl in TASK_NAME_MAPPING.values() for v in vl]
choices = ["A", "B", "C", "D"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test HF checkpoint.')
    parser.add_argument('-c', '--checkpoint-path', type=str, help='Checkpoint path', default="Qwen/Qwen-7B-Chat")
    parser.add_argument('-s', '--seed', type=int, default=1234, help='Random seed')

    # Provide extra arguments required for tasks
    group = parser.add_argument_group(title='Evaluation options')
    group.add_argument('-d', '--eval_data_path', type=str,
                       help='Path to eval data')
    group.add_argument("--debug", action='store_true', default=False,
                       help='Print infos.')
    group.add_argument("--overwrite", action='store_true', default=False,
                       help='Overwrite existed results')

    args = parser.parse_args()
    set_seed(args.seed)

    main(args)
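A typical invocation of the MMLU script above; the data directory is a placeholder and is assumed to contain the MMLU test/ CSVs named <subject>_test.csv, with per-subject results written to outs_chat/mmlu_eval_result/:

python eval/evaluate_chatml_mmlu.py -d data/mmlu -c Qwen/Qwen-7B-Chat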