thomwolf (HF staff) committed on
Commit
f1d3dc6
1 Parent(s): 9d018f5

add eval code

custom_evaluation_tasks.py ADDED
@@ -0,0 +1,650 @@
1
+ # ruff: noqa: F405, F403, F401
2
+ """
3
+ Custom evaluation tasks for lighteval
4
+
5
+ This file generally create just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval.
6
+ """
7
+ import re
8
+ from dataclasses import asdict
9
+ from typing import Dict, List, Tuple
10
+
11
+ from custom_evaluation_utils import *
12
+ from lighteval.tasks.requests import Doc
13
+
14
+ # fmt: off
15
+ LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]
16
+ # fmt: on
17
+
18
+ _TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = []
19
+ _TASKS: List[CustomEvaluationTask] = []
20
+
21
+ ## COMMON_SENSE_REASONING_TASKS ##
22
+ COMMON_SENSE_REASONING_TASKS = [
23
+ CustomEvaluationTask(
24
+ name="hellaswag",
25
+ prompt_function="hellaswag_prompt",
26
+ hf_repo="hellaswag",
27
+ hf_subset="default",
28
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
29
+ ),
30
+ CustomEvaluationTask(
31
+ name="winogrande",
32
+ prompt_function="winogrande",
33
+ hf_repo="winogrande",
34
+ hf_subset="winogrande_xl",
35
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
36
+ ),
37
+ CustomEvaluationTask(
38
+ name="piqa",
39
+ prompt_function="piqa_harness",
40
+ hf_repo="piqa",
41
+ hf_subset="plain_text",
42
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
43
+ ),
44
+ CustomEvaluationTask(
45
+ name="siqa",
46
+ prompt_function="siqa_prompt",
47
+ hf_repo="lighteval/siqa",
48
+ hf_subset="default",
49
+ hf_avail_splits=["train", "validation"],
50
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
51
+ ),
52
+ CustomEvaluationTask(
53
+ name="openbookqa",
54
+ prompt_function="openbookqa",
55
+ hf_repo="openbookqa",
56
+ hf_subset="main",
57
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
58
+ ),
59
+ CustomEvaluationTask(
60
+ name="arc:easy",
61
+ prompt_function="arc",
62
+ hf_repo="ai2_arc",
63
+ hf_subset="ARC-Easy",
64
+ evaluation_splits=["test"],
65
+ generation_size=1,
66
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
67
+ ),
68
+ CustomEvaluationTask(
69
+ name="arc:challenge",
70
+ prompt_function="arc",
71
+ hf_repo="ai2_arc",
72
+ hf_subset="ARC-Challenge",
73
+ evaluation_splits=["test"],
74
+ generation_size=1,
75
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
76
+ ),
77
+ CustomEvaluationTask(
78
+ name="commonsense_qa",
79
+ prompt_function="commonsense_qa_prompt",
80
+ hf_repo="commonsense_qa",
81
+ hf_subset="default",
82
+ metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
83
+ ),
84
+ ]
85
+
86
+
87
+ def commonsense_qa_prompt(line, task_name: str = None):
88
+ return Doc(
89
+ task_name=task_name,
90
+ query=line["question"],
91
+ choices=[f" {c}" for c in line["choices"]["text"]],
92
+ gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
93
+ instruction="",
94
+ )
95
+
96
+
97
+ def siqa_prompt(line, task_name: str = None):
98
+ return Doc(
99
+ task_name=task_name,
100
+ query=line["context"] + " " + line["question"],
101
+ choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]],
102
+ gold_index=int(line["label"]) - 1,
103
+ instruction="",
104
+ )
105
+
106
+
107
+ def hellaswag_prompt(line, task_name: str = None):
108
+ def preprocess(text):
109
+ """Comes from AiHarness"""
110
+ # text = text.strip()
111
+ # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
112
+ text = text.replace(" [title]", ". ")
113
+ text = re.sub("\\[.*?\\]", "", text)
114
+ text = text.replace("  ", " ")  # collapse double spaces
115
+ return text
116
+
117
+ ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
118
+ return Doc(
119
+ task_name=task_name,
120
+ query=preprocess(line["activity_label"] + ": " + ctx),
121
+ choices=[" " + preprocess(ending) for ending in line["endings"]],
122
+ gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test
123
+ # "metric": "choices_loglikelihood",
124
+ )
125
+
126
+
127
+ # 0-shot for common-sense reasoning
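+ # Task strings appear to follow lighteval's "suite|task|num_fewshot|truncation" convention,
+ # e.g. "custom|hellaswag|0|1" runs hellaswag zero-shot from the "custom" suite (format inferred from its use in this file).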
128
+ COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS]
129
+ _TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING)
130
+ _TASKS += COMMON_SENSE_REASONING_TASKS
131
+
132
+ ## WORLD_KNOWLEDGE_TASKS ##
133
+
134
+ WORLD_KNOWLEDGE_TASKS = [
135
+ CustomEvaluationTask(
136
+ name="trivia_qa",
137
+ prompt_function="triviaqa",
138
+ hf_repo="trivia_qa",
139
+ hf_subset="rc.nocontext",
140
+ metric=[Metrics.quasi_exact_match2],
141
+ generation_size=20,
142
+ stop_sequence=["\n", ".", ","],
143
+ ),
144
+ CustomEvaluationTask(
145
+ name="natural_questions",
146
+ prompt_function="natural_questions_prompt",
147
+ hf_repo="lighteval/natural_questions_clean",
148
+ hf_subset="default",
149
+ metric=[Metrics.quasi_exact_match2],
150
+ generation_size=20,
151
+ stop_sequence=["\n", ".", ","],
152
+ ),
153
+ ]
154
+
155
+
156
+ def natural_questions_prompt(line, task_name: str = None):
157
+ return Doc(
158
+ task_name=task_name,
159
+ query=line["question"] + "?\nAnswer: ",
160
+ choices=[line["short_answers"]],
161
+ gold_index=0,
162
+ instruction="",
163
+ )
164
+
165
+
166
+ WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS]
167
+ # WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS}
168
+ _TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING)
169
+ _TASKS += WORLD_KNOWLEDGE_TASKS
170
+
171
+ ## Reading comprehension ##
172
+
173
+ READING_COMP_TASKS = [
174
+ CustomEvaluationTask(
175
+ name="super_glue:boolq",
176
+ prompt_function="boolq_prompt",
177
+ hf_repo="super_glue",
178
+ hf_subset="boolq",
179
+ metric=[Metrics.target_perplexity],
180
+ ),
181
+ CustomEvaluationTask(
182
+ name="quac",
183
+ prompt_function="quac",
184
+ hf_repo="lighteval/quac_helm",
185
+ hf_subset="default",
186
+ metric=[Metrics.quasi_exact_match2],
187
+ generation_size=20,
188
+ stop_sequence=["\n", ".", ","],
189
+ ),
190
+ ]
191
+
192
+
193
+ def boolq_prompt(line, task_name: str = None):
194
+ return Doc(
195
+ task_name=task_name,
196
+ query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:",
197
+ choices=[" No", " Yes"], # Only gold
198
+ gold_index=int(line["label"]),
199
+ )
200
+
201
+
202
+ READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS]
203
+ _TASKS_STRINGS.extend(READING_COMP_STRING)
204
+ _TASKS += READING_COMP_TASKS
205
+
206
+
207
+ ## MATH ##
208
+ class CustomMathEvaluationTask(CustomEvaluationTask):
209
+ """Custom class for math tasks with all the defaults set"""
210
+
211
+ def __init__(
212
+ self,
213
+ name,
214
+ prompt_function="math",
215
+ hf_repo="lighteval/MATH",
216
+ hf_subset=None,
217
+ metric=[Metrics.math_quasi_exact_match],
218
+ hf_avail_splits=None,
219
+ evaluation_splits=["test"],
220
+ few_shots_split=None,
221
+ few_shots_select=None,
222
+ suite=["custom"],
223
+ generation_size=40,
224
+ stop_sequence=None,
225
+ output_regex=None,
226
+ frozen=False,
227
+ ):
228
+ super().__init__(
229
+ name=name,
230
+ prompt_function=prompt_function,
231
+ hf_repo=hf_repo,
232
+ hf_subset=hf_subset,
233
+ metric=metric,
234
+ hf_avail_splits=hf_avail_splits,
235
+ evaluation_splits=evaluation_splits,
236
+ few_shots_split=few_shots_split,
237
+ few_shots_select=few_shots_select,
238
+ suite=suite,
239
+ generation_size=generation_size,
240
+ stop_sequence=stop_sequence,
241
+ output_regex=output_regex,
242
+ frozen=frozen,
243
+ )
244
+
245
+
246
+ MATH_TASKS = [
247
+ CustomMathEvaluationTask(name="math:algebra", hf_subset="algebra"),
248
+ CustomMathEvaluationTask(name="math:counting_and_probability", hf_subset="counting_and_probability"),
249
+ CustomMathEvaluationTask(name="math:geometry", hf_subset="geometry"),
250
+ CustomMathEvaluationTask(name="math:intermediate_algebra", hf_subset="intermediate_algebra"),
251
+ CustomMathEvaluationTask(name="math:number_theory", hf_subset="number_theory"),
252
+ CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"),
253
+ CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"),
254
+ ]
255
+ GSM8K = CustomEvaluationTask(
256
+ name="gsm8k",
257
+ prompt_function="gsm8k",
258
+ hf_repo="gsm8k",
259
+ hf_subset="main",
260
+ hf_avail_splits=["train", "test"],
261
+ evaluation_splits=["test"],
262
+ metric=[Metrics.perfect_exact_match],
263
+ generation_size=10,
264
+ stop_sequence=["\n"],
265
+ )
266
+
267
+
268
+ MATH_STRING = [(t, f"custom|{t.name}|4|1") for t in MATH_TASKS]
269
+ GSM8K_STRING = [(GSM8K, f"custom|{GSM8K.name}|8|1")]
270
+ _TASKS_STRINGS.extend(MATH_STRING)
271
+ _TASKS_STRINGS.extend(GSM8K_STRING)
272
+ _TASKS += MATH_TASKS + [GSM8K]
273
+
274
+
275
+ ## MMLU ##
276
+ class CustomMMLUEvaluationTask(CustomEvaluationTask):
277
+ def __init__(
278
+ self,
279
+ name,
280
+ prompt_function="mmlu_prompt",
281
+ hf_repo="lighteval/mmlu",
282
+ hf_subset=None,
283
+ # metric=[Metrics.loglikelihood_acc_single_token],
284
+ metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
285
+ hf_avail_splits=None,
286
+ evaluation_splits=["test"],
287
+ few_shots_split="dev",
288
+ few_shots_select=None,
289
+ suite=None,
290
+ generation_size=-1,
291
+ stop_sequence=None,
292
+ output_regex=None,
293
+ frozen=False,
294
+ ):
295
+ super().__init__(
296
+ name=name,
297
+ prompt_function=prompt_function,
298
+ hf_repo=hf_repo,
299
+ hf_subset=hf_subset,
300
+ metric=metric,
301
+ hf_avail_splits=hf_avail_splits,
302
+ evaluation_splits=evaluation_splits,
303
+ few_shots_split=few_shots_split,
304
+ few_shots_select=few_shots_select,
305
+ suite=suite,
306
+ generation_size=generation_size,
307
+ stop_sequence=stop_sequence,
308
+ output_regex=output_regex,
309
+ frozen=frozen,
310
+ )
311
+
312
+
313
+ MMLU_TASKS = [
314
+ CustomMMLUEvaluationTask(name="mmlu:abstract_algebra", hf_subset="abstract_algebra"),
315
+ CustomMMLUEvaluationTask(name="mmlu:anatomy", hf_subset="anatomy"),
316
+ CustomMMLUEvaluationTask(name="mmlu:astronomy", hf_subset="astronomy"),
317
+ CustomMMLUEvaluationTask(name="mmlu:business_ethics", hf_subset="business_ethics"),
318
+ CustomMMLUEvaluationTask(name="mmlu:clinical_knowledge", hf_subset="clinical_knowledge"),
319
+ CustomMMLUEvaluationTask(name="mmlu:college_biology", hf_subset="college_biology"),
320
+ CustomMMLUEvaluationTask(name="mmlu:college_chemistry", hf_subset="college_chemistry"),
321
+ CustomMMLUEvaluationTask(name="mmlu:college_computer_science", hf_subset="college_computer_science"),
322
+ CustomMMLUEvaluationTask(name="mmlu:college_mathematics", hf_subset="college_mathematics"),
323
+ CustomMMLUEvaluationTask(name="mmlu:college_medicine", hf_subset="college_medicine"),
324
+ CustomMMLUEvaluationTask(name="mmlu:college_physics", hf_subset="college_physics"),
325
+ CustomMMLUEvaluationTask(name="mmlu:computer_security", hf_subset="computer_security"),
326
+ CustomMMLUEvaluationTask(name="mmlu:conceptual_physics", hf_subset="conceptual_physics"),
327
+ CustomMMLUEvaluationTask(name="mmlu:econometrics", hf_subset="econometrics"),
328
+ CustomMMLUEvaluationTask(name="mmlu:electrical_engineering", hf_subset="electrical_engineering"),
329
+ CustomMMLUEvaluationTask(name="mmlu:elementary_mathematics", hf_subset="elementary_mathematics"),
330
+ CustomMMLUEvaluationTask(name="mmlu:formal_logic", hf_subset="formal_logic"),
331
+ CustomMMLUEvaluationTask(name="mmlu:global_facts", hf_subset="global_facts"),
332
+ CustomMMLUEvaluationTask(name="mmlu:high_school_biology", hf_subset="high_school_biology"),
333
+ CustomMMLUEvaluationTask(name="mmlu:high_school_chemistry", hf_subset="high_school_chemistry"),
334
+ CustomMMLUEvaluationTask(name="mmlu:high_school_computer_science", hf_subset="high_school_computer_science"),
335
+ CustomMMLUEvaluationTask(name="mmlu:high_school_european_history", hf_subset="high_school_european_history"),
336
+ CustomMMLUEvaluationTask(name="mmlu:high_school_geography", hf_subset="high_school_geography"),
337
+ CustomMMLUEvaluationTask(
338
+ name="mmlu:high_school_government_and_politics", hf_subset="high_school_government_and_politics"
339
+ ),
340
+ CustomMMLUEvaluationTask(name="mmlu:high_school_macroeconomics", hf_subset="high_school_macroeconomics"),
341
+ CustomMMLUEvaluationTask(name="mmlu:high_school_mathematics", hf_subset="high_school_mathematics"),
342
+ CustomMMLUEvaluationTask(name="mmlu:high_school_microeconomics", hf_subset="high_school_microeconomics"),
343
+ CustomMMLUEvaluationTask(name="mmlu:high_school_physics", hf_subset="high_school_physics"),
344
+ CustomMMLUEvaluationTask(name="mmlu:high_school_psychology", hf_subset="high_school_psychology"),
345
+ CustomMMLUEvaluationTask(name="mmlu:high_school_statistics", hf_subset="high_school_statistics"),
346
+ CustomMMLUEvaluationTask(name="mmlu:high_school_us_history", hf_subset="high_school_us_history"),
347
+ CustomMMLUEvaluationTask(name="mmlu:high_school_world_history", hf_subset="high_school_world_history"),
348
+ CustomMMLUEvaluationTask(name="mmlu:human_aging", hf_subset="human_aging"),
349
+ CustomMMLUEvaluationTask(name="mmlu:human_sexuality", hf_subset="human_sexuality"),
350
+ CustomMMLUEvaluationTask(name="mmlu:international_law", hf_subset="international_law"),
351
+ CustomMMLUEvaluationTask(name="mmlu:jurisprudence", hf_subset="jurisprudence"),
352
+ CustomMMLUEvaluationTask(name="mmlu:logical_fallacies", hf_subset="logical_fallacies"),
353
+ CustomMMLUEvaluationTask(name="mmlu:machine_learning", hf_subset="machine_learning"),
354
+ CustomMMLUEvaluationTask(name="mmlu:management", hf_subset="management"),
355
+ CustomMMLUEvaluationTask(name="mmlu:marketing", hf_subset="marketing"),
356
+ CustomMMLUEvaluationTask(name="mmlu:medical_genetics", hf_subset="medical_genetics"),
357
+ CustomMMLUEvaluationTask(name="mmlu:miscellaneous", hf_subset="miscellaneous"),
358
+ CustomMMLUEvaluationTask(name="mmlu:moral_disputes", hf_subset="moral_disputes"),
359
+ CustomMMLUEvaluationTask(name="mmlu:moral_scenarios", hf_subset="moral_scenarios"),
360
+ CustomMMLUEvaluationTask(name="mmlu:nutrition", hf_subset="nutrition"),
361
+ CustomMMLUEvaluationTask(name="mmlu:philosophy", hf_subset="philosophy"),
362
+ CustomMMLUEvaluationTask(name="mmlu:prehistory", hf_subset="prehistory"),
363
+ CustomMMLUEvaluationTask(name="mmlu:professional_accounting", hf_subset="professional_accounting"),
364
+ CustomMMLUEvaluationTask(name="mmlu:professional_law", hf_subset="professional_law"),
365
+ CustomMMLUEvaluationTask(name="mmlu:professional_medicine", hf_subset="professional_medicine"),
366
+ CustomMMLUEvaluationTask(name="mmlu:professional_psychology", hf_subset="professional_psychology"),
367
+ CustomMMLUEvaluationTask(name="mmlu:public_relations", hf_subset="public_relations"),
368
+ CustomMMLUEvaluationTask(name="mmlu:security_studies", hf_subset="security_studies"),
369
+ CustomMMLUEvaluationTask(name="mmlu:sociology", hf_subset="sociology"),
370
+ CustomMMLUEvaluationTask(name="mmlu:us_foreign_policy", hf_subset="us_foreign_policy"),
371
+ CustomMMLUEvaluationTask(name="mmlu:virology", hf_subset="virology"),
372
+ CustomMMLUEvaluationTask(name="mmlu:world_religions", hf_subset="world_religions"),
373
+ ]
374
+
375
+
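+ # Letter-based MMLU prompt (choices " A"/" B"/" C"/" D"); kept for reference, the tasks above default to mmlu_prompt.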
376
+ def mmlu_harness(line, task_name: str = None):
377
+ topic = line["subject"]
378
+ prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
379
+ prompt += line["question"] + "\n"
380
+ prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
381
+ prompt += "Answer:"
382
+
383
+ gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
384
+ "__few_shots" in line and line["__few_shots"] is True # We are adding few shots
385
+
386
+ return Doc(
387
+ task_name=task_name,
388
+ query=prompt,
389
+ choices=[" A", " B", " C", " D"],
390
+ target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
391
+ gold_index=gold_ix,
392
+ instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
393
+ )
394
+
395
+
396
+ def mmlu_prompt(line, task_name: str = None):
397
+ """MMLU prompt without letters"""
398
+ topic = line["subject"]
399
+ prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: "
400
+ prompt += line["question"] + "\nAnswer:"
401
+
402
+ return Doc(
403
+ task_name=task_name,
404
+ query=prompt,
405
+ choices=[f" {c}" for c in line["choices"]],
406
+ gold_index=line["answer"],
407
+ instruction=f"The following are questions about {topic.replace('_', ' ')}.\n",
408
+ )
409
+
410
+
411
+ # MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS}
412
+ MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS]
413
+ _TASKS_STRINGS.extend(MMLU_STRING)
414
+ _TASKS += MMLU_TASKS
415
+
416
+ ## BBH ##
417
+
418
+
419
+ class CustomBBHEvaluationTask(CustomEvaluationTask):
420
+ def __init__(
421
+ self,
422
+ name,
423
+ prompt_function="bbh_prompt",
424
+ hf_repo="lighteval/big_bench_hard",
425
+ hf_subset=None,
426
+ metric=[Metrics.exact_match],
427
+ hf_avail_splits=["train"],
428
+ evaluation_splits=["train"],
429
+ few_shots_split="train",
430
+ few_shots_select=None,
431
+ suite=None,
432
+ generation_size=4,
433
+ stop_sequence=None,
434
+ output_regex=None,
435
+ frozen=False,
436
+ ):
437
+ super().__init__(
438
+ name=name,
439
+ prompt_function=prompt_function,
440
+ hf_repo=hf_repo,
441
+ hf_subset=hf_subset,
442
+ metric=metric,
443
+ hf_avail_splits=hf_avail_splits,
444
+ evaluation_splits=evaluation_splits,
445
+ few_shots_split=few_shots_split,
446
+ few_shots_select=few_shots_select,
447
+ suite=suite,
448
+ generation_size=generation_size,
449
+ stop_sequence=stop_sequence,
450
+ output_regex=output_regex,
451
+ frozen=frozen,
452
+ )
453
+
454
+
455
+ BBH_TASKS = [
456
+ CustomBBHEvaluationTask(name="bbh:boolean_expressions", hf_subset="boolean_expressions"),
457
+ CustomBBHEvaluationTask(name="bbh:causal_judgement", hf_subset="causal_judgement"),
458
+ CustomBBHEvaluationTask(name="bbh:date_understanding", hf_subset="date_understanding"),
459
+ CustomBBHEvaluationTask(name="bbh:disambiguation_qa", hf_subset="disambiguation_qa"),
460
+ CustomBBHEvaluationTask(name="bbh:dyck_languages", hf_subset="dyck_languages"),
461
+ CustomBBHEvaluationTask(name="bbh:formal_fallacies", hf_subset="formal_fallacies"),
462
+ CustomBBHEvaluationTask(name="bbh:geometric_shapes", hf_subset="geometric_shapes"),
463
+ CustomBBHEvaluationTask(name="bbh:hyperbaton", hf_subset="hyperbaton"),
464
+ CustomBBHEvaluationTask(name="bbh:logical_deduction_five_objects", hf_subset="logical_deduction_five_objects"),
465
+ CustomBBHEvaluationTask(name="bbh:logical_deduction_seven_objects", hf_subset="logical_deduction_seven_objects"),
466
+ CustomBBHEvaluationTask(name="bbh:logical_deduction_three_objects", hf_subset="logical_deduction_three_objects"),
467
+ CustomBBHEvaluationTask(name="bbh:movie_recommendation", hf_subset="movie_recommendation"),
468
+ CustomBBHEvaluationTask(name="bbh:multistep_arithmetic_two", hf_subset="multistep_arithmetic_two"),
469
+ CustomBBHEvaluationTask(name="bbh:navigate", hf_subset="navigate"),
470
+ CustomBBHEvaluationTask(name="bbh:object_counting", hf_subset="object_counting"),
471
+ CustomBBHEvaluationTask(name="bbh:penguins_in_a_table", hf_subset="penguins_in_a_table"),
472
+ CustomBBHEvaluationTask(name="bbh:reasoning_about_colored_objects", hf_subset="reasoning_about_colored_objects"),
473
+ CustomBBHEvaluationTask(name="bbh:ruin_names", hf_subset="ruin_names"),
474
+ CustomBBHEvaluationTask(
475
+ name="bbh:salient_translation_error_detection", hf_subset="salient_translation_error_detection"
476
+ ),
477
+ CustomBBHEvaluationTask(name="bbh:snarks", hf_subset="snarks"),
478
+ CustomBBHEvaluationTask(name="bbh:sports_understanding", hf_subset="sports_understanding"),
479
+ CustomBBHEvaluationTask(name="bbh:temporal_sequences", hf_subset="temporal_sequences"),
480
+ CustomBBHEvaluationTask(
481
+ name="bbh:tracking_shuffled_objects_five_objects", hf_subset="tracking_shuffled_objects_five_objects"
482
+ ),
483
+ CustomBBHEvaluationTask(
484
+ name="bbh:tracking_shuffled_objects_seven_objects", hf_subset="tracking_shuffled_objects_seven_objects"
485
+ ),
486
+ CustomBBHEvaluationTask(
487
+ name="bbh:tracking_shuffled_objects_three_objects", hf_subset="tracking_shuffled_objects_three_objects"
488
+ ),
489
+ CustomBBHEvaluationTask(name="bbh:web_of_lies", hf_subset="web_of_lies"),
490
+ CustomBBHEvaluationTask(name="bbh:word_sorting", hf_subset="word_sorting"),
491
+ ]
492
+
493
+
494
+ def bbh_prompt(line, task_name: str = None):
495
+ return Doc(
496
+ task_name=task_name,
497
+ query=line["input"] + "\nAnswer: ",
498
+ choices=[line["target"]],
499
+ gold_index=0,
500
+ )
501
+
502
+
503
+ # BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS}
504
+ BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS]
505
+ _TASKS_STRINGS.extend(BBH_STRING)
506
+ _TASKS += BBH_TASKS
507
+
508
+
509
+ ## AGI eval ##
510
+ class CustomAGIEvalEvaluationTask(CustomEvaluationTask):
511
+ def __init__(
512
+ self,
513
+ name,
514
+ prompt_function="agi_eval_prompt_no_letters",
515
+ hf_repo="lighteval/agi_eval_en",
516
+ hf_subset=None,
517
+ # metric=[Metrics.loglikelihood_acc_single_token],
518
+ metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
519
+ hf_avail_splits=["train", "validation"],
520
+ evaluation_splits=["train"],
521
+ few_shots_split="validation",
522
+ few_shots_select=None,
523
+ suite=None,
524
+ generation_size=-1,
525
+ stop_sequence=None,
526
+ output_regex=None,
527
+ frozen=False,
528
+ ):
529
+ super().__init__(
530
+ name=name,
531
+ prompt_function=prompt_function,
532
+ hf_repo=hf_repo,
533
+ hf_subset=hf_subset,
534
+ metric=metric,
535
+ hf_avail_splits=hf_avail_splits,
536
+ evaluation_splits=evaluation_splits,
537
+ few_shots_split=few_shots_split,
538
+ few_shots_select=few_shots_select,
539
+ suite=suite,
540
+ generation_size=generation_size,
541
+ stop_sequence=stop_sequence,
542
+ output_regex=output_regex,
543
+ frozen=frozen,
544
+ )
545
+
546
+
547
+ AGIEVAL_TASKS = [
548
+ CustomAGIEvalEvaluationTask(name="agi_eval:aqua_rat", hf_subset="aqua_rat"),
549
+ CustomAGIEvalEvaluationTask(name="agi_eval:logiqa-en", hf_subset="logiqa-en"),
550
+ CustomAGIEvalEvaluationTask(name="agi_eval:lsat-ar", hf_subset="lsat-ar"),
551
+ CustomAGIEvalEvaluationTask(name="agi_eval:lsat-lr", hf_subset="lsat-lr"),
552
+ CustomAGIEvalEvaluationTask(name="agi_eval:lsat-rc", hf_subset="lsat-rc"),
553
+ CustomAGIEvalEvaluationTask(
554
+ name="agi_eval:math",
555
+ hf_subset="math",
556
+ prompt_function="agi_eval_math_prompt",
557
+ metric=[Metrics.exact_match, Metrics.quasi_exact_match2],
558
+ generation_size=40,
559
+ ),
560
+ CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"),
561
+ CustomAGIEvalEvaluationTask(name="agi_eval:sat-math", hf_subset="sat-math"),
562
+ ]
563
+
564
+
565
+ def agi_eval_math_prompt(line, task_name: str = None):
566
+ return Doc(
567
+ task_name=task_name,
568
+ query=line["question"],
569
+ choices=[line["answer"]],
570
+ gold_index=0,
571
+ instruction="",
572
+ )
573
+
574
+
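+ # Letter-based AGIEval prompt; kept for reference, the tasks above default to agi_eval_prompt_no_letters.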
575
+ def agi_eval_prompt(line, task_name: str = None):
576
+ cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]]
577
+ prompt = "The following are multiple choice questions (with answers).\n\n"
578
+ prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n"
579
+ prompt += "Answer: "
580
+
581
+ choices = LETTER_INDICES[: len(line["options"])]
582
+
583
+ output = Doc(
584
+ query=prompt,
585
+ instruction="The following are multiple choice questions (with answers).\n\n",
586
+ )
587
+
588
+ if line["label"]:
589
+ output.choices = choices
590
+ output.gold_index = LETTER_INDICES.index(line["label"].strip())
591
+ else:
592
+ output.choices = [line["answer"]]
593
+ output.gold_index = 0
594
+
595
+ return output
596
+
597
+
598
+ def agi_eval_prompt_no_letters(line, task_name: str = None):
599
+ cleaned_options = [
600
+ " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "")
601
+ for o in line["options"]
602
+ ]
603
+
604
+ output = Doc(
605
+ query=line["question"],
606
+ choices=cleaned_options,
607
+ gold_index=LETTER_INDICES.index(line["label"].strip()),
608
+ instruction="",
609
+ )
610
+
611
+ return output
612
+
613
+
614
+ # AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS}
615
+ AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS]
616
+ _TASKS_STRINGS.extend(AGIEVAL_STRING)
617
+ _TASKS += AGIEVAL_TASKS
618
+
619
+
620
+ ## HUMAN EVAL ##
621
+ # human_eval = CustomEvaluationTask(
622
+ # name="human_eval",
623
+ # prompt_function="human_eval",
624
+ # hf_repo="lighteval/human_eval",
625
+ # metric=["human_eval_pass_at_1"],
626
+ # ),
627
+
628
+
629
+ def has_generative_metrics(task: CustomEvaluationTask) -> bool:
630
+ for metric in task.metric:
631
+ if metric in NEEDS_GENERATION_ONLY:
632
+ return True
633
+ return False
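+ # CustomEvaluationTask.__post_init__ stringifies the metrics, so this is a plain string membership test against NEEDS_GENERATION_ONLY.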
634
+
635
+
636
+ EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING])
637
+
638
+ # Convert to dict for lighteval
639
+ TASKS_TABLE = [asdict(task) for task in _TASKS]
640
+ # You can have a few pre-organised groups of tasks
641
+ TASKS_GROUPS = {
642
+ "all": ",".join(t[1] for t in _TASKS_STRINGS),
643
+ "early-signal": EARLY_SIGNAL_TASKS,
644
+ "non-generatives": ",".join(t for k, t in _TASKS_STRINGS if not has_generative_metrics(k)),
645
+ "generatives": ",".join(t for k, t in _TASKS_STRINGS if has_generative_metrics(k)),
646
+ }
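+ # Example: TASKS_GROUPS["early-signal"] expands to a comma-separated string such as
+ # "custom|hellaswag|0|1,custom|winogrande|0|1,...", which run_evals.py resolves through get_custom_tasks
+ # when the lighteval config sets `tasks: early-signal`.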
647
+
648
+ if __name__ == "__main__":
649
+ print(t["name"] for t in TASKS_TABLE)
650
+ print(len(TASKS_TABLE))
custom_evaluation_utils.py ADDED
@@ -0,0 +1,158 @@
1
+ """
2
+ Custom evaluation tasks for lighteval
3
+ """
4
+ from dataclasses import dataclass
5
+ from enum import Enum, auto
6
+ from typing import Optional, Tuple, Union
7
+
8
+
9
+ class Metrics(Enum):
10
+ any_target_loglikelihood_acc = auto()
11
+ bert_score = auto()
12
+ bias = auto()
13
+ bits_per_byte = auto()
14
+ bleu = auto()
15
+ bleu_1 = auto()
16
+ bleu_4 = auto()
17
+ byte_perplexity = auto()
18
+ chrf = auto()
19
+ code_eval_APPS = auto()
20
+ code_eval_HE = auto()
21
+ copyright = auto()
22
+ disinformation = auto()
23
+ exact_match = auto()
24
+ exact_set_match = auto()
25
+ extractiveness = auto()
26
+ f1_from_bags = auto()
27
+ f1_quasi = auto()
28
+ f1_sequence = auto()
29
+ f1_set_match = auto()
30
+ faithfulness = auto()
31
+ iou_set_match = auto()
32
+ log_prob = auto()
33
+ loglikelihood_acc = auto()
34
+ loglikelihood_acc_norm = auto()
35
+ loglikelihood_acc_norm_nospace = auto()
36
+ loglikelihood_acc_norm_single_token = auto()
37
+ loglikelihood_acc_single_token = auto()
38
+ loglikelihood_f1 = auto()
39
+ loglikelihood_f1_single_token = auto()
40
+ math_quasi_exact_match = auto()
41
+ mc_taco = auto()
42
+ mcc = auto()
43
+ mcc_single_token = auto()
44
+ mrr = auto()
45
+ mrr_single_token = auto()
46
+ multi_fi_numeric = auto()
47
+ one_choice_loglikelihood_acc = auto()
48
+ perfect_exact_match = auto()
49
+ prediction_perplexity = auto()
50
+ prefix_exact_match = auto()
51
+ prefix_quasi_exact_match = auto()
52
+ quasi_exact_match = auto()
+ quasi_exact_match2 = auto()  # referenced by the custom tasks and by NEEDS_GENERATION_ONLY below
53
+ ranking = auto()
54
+ recall_at_1_single_token = auto()
55
+ recall_at_2_single_token = auto()
56
+ recall_at_1 = auto()
57
+ recall_at_2 = auto()
58
+ rouge = auto()
59
+ rouge_1 = auto()
60
+ rouge_2 = auto()
61
+ rouge_l = auto()
62
+ target_perplexity = auto()
63
+ ter = auto()
64
+ toxicity = auto()
65
+ truthfulqa_mc_metrics = auto()
66
+ word_perplexity = auto()
67
+
68
+ def __str__(self):
69
+ return self.name.replace("_at_", "@")
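+ # e.g. Metrics.recall_at_1 stringifies to "recall@1", presumably to match lighteval's metric naming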
70
+
71
+
72
+ NEEDS_GENERATION_ONLY = [
73
+ "perfect_exact_match",
74
+ "exact_match",
75
+ "quasi_exact_match",
76
+ "quasi_exact_match2",
77
+ "prefix_exact_match",
78
+ "prefix_quasi_exact_match",
79
+ "math_quasi_exact_match",
80
+ "iou_set_match",
81
+ "exact_set_match",
82
+ "f1_sequence",
83
+ "f1_quasi",
84
+ "f1_set_match",
85
+ "f1_from_bags",
86
+ "chrf",
87
+ "ter",
88
+ "rouge",
89
+ "rouge_1",
90
+ "rouge_2",
91
+ "rouge_l",
92
+ "faithfulness",
93
+ "extractiveness",
94
+ "bert_score",
95
+ "bleu",
96
+ "bleu_1",
97
+ "bleu_4",
98
+ "bias",
99
+ "toxicity",
100
+ "code_eval_HE",
101
+ "code_eval_APPS",
102
+ "copyright",
103
+ ]
104
+
105
+
106
+ @dataclass(unsafe_hash=True)
107
+ class CustomEvaluationTask:
108
+ name: str
109
+ prompt_function: str
110
+ hf_repo: str
111
+ hf_subset: str
112
+ metric: Tuple[Union[str, Metrics]]
113
+ hf_avail_splits: Optional[Tuple[str]] = None
114
+ evaluation_splits: Optional[Tuple[str]] = None
115
+ few_shots_split: Optional[str] = None
116
+ few_shots_select: Optional[str] = None
117
+ generation_size: int = -1
118
+ stop_sequence: Optional[Tuple[str]] = None
119
+ output_regex: Optional[str] = None
120
+
121
+ frozen: bool = False
122
+ suite: Optional[Tuple[str]] = None # we use this to know if we should use a custom lighteval or bigcode task
123
+
124
+ def __post_init__(self):
125
+ self.metric = [str(m) for m in self.metric]
126
+ if self.suite is None:
127
+ self.suite = ["custom"]
128
+ if self.hf_avail_splits is None:
129
+ self.hf_avail_splits = ["train", "validation", "test"]
130
+ if self.evaluation_splits is None:
131
+ self.evaluation_splits = ["validation"]
132
+ if self.stop_sequence is None:
133
+ self.stop_sequence = ["\n"]
134
+
135
+ # Convert list to tuple for hashing
136
+ self.metric = tuple(self.metric)
137
+ self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits else None
138
+ self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits else None
139
+ self.suite = tuple(self.suite) if self.suite else None
140
+ self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence else None
141
+
142
+
143
+ @dataclass(unsafe_hash=True)
144
+ class BigCodeEvaluationTask:
145
+ name: str
146
+ bigcode_task: str
147
+ bigcode_task_kwargs: Optional[dict] = None
148
+ n_samples: int = 1
149
+ prefix: Optional[str] = None
150
+
151
+ suite: Tuple[str] = None
152
+
153
+ def __post_init__(self):
154
+ if self.suite is None:
155
+ self.suite = ("bigcode",)
156
+
157
+ # Convert list to tuple for hashing
158
+ self.suite = tuple(self.suite)
lighteval_eval_config.yaml ADDED
@@ -0,0 +1,45 @@
1
+ checkpoints: null
2
+ data: null
3
+ experiment_logger: null
4
+ general: null
5
+ kill_switch_path: null
6
+ lighteval:
7
+ batch_size: 24
8
+ checkpoints_path: null
9
+ generation: null
10
+ logging:
11
+ hub_repo_details: null
12
+ hub_repo_results: null
13
+ hub_repo_tensorboard: HuggingFaceBR4/thomwolf-nanotron-mistral-7b
14
+ local_output_path: /scratch/thomwolf/lighteval/nanotron-mistral-7b
15
+ push_details_to_hub: false
16
+ push_results_to_hub: false
17
+ push_results_to_tensorboard: true
18
+ tensorboard_metric_prefix: e
19
+ parallelism:
20
+ dp: 4
21
+ pp: 1
22
+ pp_engine: 1f1b
23
+ recompute_granularity: null
24
+ tp: 2
25
+ tp_linear_async_communication: false
26
+ tp_mode: ALL_REDUCE
27
+ slurm: null
28
+ slurm_script_dir: null
29
+ slurm_template: null
30
+ tasks:
31
+ custom_tasks_file: ./custom_evaluation_tasks.py
32
+ dataset_loading_processes: 8
33
+ max_samples: 1000
34
+ multichoice_continuations_start_space: null
35
+ no_multichoice_continuations_start_space: null
36
+ num_fewshot_seeds: null
37
+ tasks: early-signal
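+ # "early-signal" is a task group defined in TASKS_GROUPS of custom_evaluation_tasks.py;
+ # a raw task string such as "custom|hellaswag|0|1" should also be accepted here.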
38
+ logging: null
39
+ model: null
40
+ optimizer: null
41
+ parallelism: null
42
+ profiler: null
43
+ s3_upload: null
44
+ tokenizer: null
45
+ tokens: null
run_evals.py ADDED
@@ -0,0 +1,442 @@
1
+ """
2
+ Nanotron Inference Script
3
+
4
+ Usage:
5
+ ```
6
+ export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
7
+ torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/Mistral-7B-v0.1/config.yaml \
8
+ --lighteval-override ./lighteval_eval_config.yaml
9
+ ```
10
+ """
11
+ # flake8: noqa: C901
12
+ import argparse
13
+ import os
14
+ import random
15
+ import time
16
+ from dataclasses import asdict
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+ from huggingface_hub import HFSummaryWriter
22
+ from lighteval.evaluator import evaluate, make_results_table
23
+ from lighteval.logging.evaluation_tracker import EvaluationTracker
24
+ from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
25
+ from lighteval.logging.info_loggers import (
26
+ DetailsLogger,
27
+ )
28
+ from lighteval.models.model_loader import ModelInfo
29
+ from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
30
+ from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
31
+ from nanotron import distributed as dist
32
+ from nanotron import logging
33
+ from nanotron.config import get_config_from_file
34
+ from nanotron.logging import get_logger, log_rank
35
+ from nanotron.parallel.context import ParallelContext
36
+ from nanotron.utils import local_ranks_zero_first
37
+
38
+ from brrr.config import BrrrConfig
39
+ from brrr.experiment_loggers import flatten_dict, obj_to_markdown
40
+ from brrr.s3_checkpoints import fs_copy
41
+ from brrr.utils import check_env
42
+
43
+ from lighteval.models.brrr_models import BRRRModel
44
+
45
+ from modeling_mistral import MistralForTraining
46
+ from config_mistral import MistralConfig
47
+
48
+ logger = get_logger(__name__)
49
+
50
+ TOKEN = os.getenv("HF_TOKEN")
51
+ CACHE_DIR = os.getenv("HF_HOME", "/scratch")
52
+
53
+
54
+ def get_parser():
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument(
57
+ "--checkpoint-config-path",
58
+ type=str,
59
+ required=True,
60
+ help="Path to the brr checkpoint YAML or python config file, potentially on S3",
61
+ )
62
+ parser.add_argument(
63
+ "--lighteval-override",
64
+ type=str,
65
+ help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
66
+ )
67
+ parser.add_argument(
68
+ "--tokenizer",
69
+ type=str,
70
+ help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
71
+ )
72
+ parser.add_argument(
73
+ "--s5cmd-path",
74
+ type=str,
75
+ default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd",
76
+ help="Path to s5cmd install",
77
+ )
78
+ parser.add_argument(
79
+ "--s5cmd-numworkers",
80
+ type=int,
81
+ default=64,
82
+ help="s5cmd num workers (optional)",
83
+ )
84
+ parser.add_argument(
85
+ "--s5cmd-concurrency",
86
+ type=int,
87
+ default=10,
88
+ help="s5cmd concurrency (optional)",
89
+ )
90
+ parser.add_argument(
91
+ "--cache-dir",
92
+ type=str,
93
+ default="",
94
+ help="Cache directory",
95
+ )
96
+
97
+ return parser
98
+
99
+
100
+ def push_results_to_wandb( # noqa: C901
101
+ config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
102
+ ):
103
+ # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
104
+ lighteval_config = config.lighteval
105
+ try:
106
+ global_step = config.general.step
107
+ except ValueError:
108
+ global_step = 0
109
+ if config.lighteval.logging.tensorboard_metric_prefix is not None:
110
+ prefix = config.lighteval.logging.tensorboard_metric_prefix
111
+ else:
112
+ prefix = "eval"
113
+ output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
114
+ output_dir_tb.mkdir(parents=True, exist_ok=True)
115
+
116
+ os.environ["WANDB_DISABLE_SERVICE"] = "True"
117
+ import wandb
118
+
119
+ wandb.tensorboard.patch(root_logdir=config.lighteval.logging.local_output_path)
120
+ hlog("Starting wandb with WANDB_DISABLE_SERVICE=True")
121
+ wandb.init(
122
+ project=config.lighteval.wandb.wandb_project,
123
+ entity=config.lighteval.wandb.wandb_entity,
124
+ name=config.lighteval.wandb.wandb_run_name,
125
+ config=config.as_dict(),
126
+ # sync_tensorboard=True,
127
+ resume=True,
128
+ )
129
+ wb_dict = {}
130
+ bench_averages = {}
131
+ for name, values in results.items():
132
+ splited_name = name.split("|")
133
+ if len(splited_name) == 3:
134
+ _, task_name, _ = splited_name
135
+ else:
136
+ task_name = name
137
+ bench_suite = None
138
+ if ":" in task_name:
139
+ bench_suite = task_name.split(":")[0] # e.g. MMLU
140
+ hlog(f"bench_suite {bench_suite} in {task_name}")
141
+ for metric, value in values.items():
142
+ if "stderr" in metric:
143
+ continue
144
+ if bench_suite not in bench_averages:
145
+ bench_averages[bench_suite] = {}
146
+ bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
147
+ hlog(f"Pushing {task_name} {values} to tensorboard")
148
+ for metric, value in values.items():
149
+ if "stderr" in metric:
150
+ wb_dict[f"stderr_{metric}/{task_name}"] = value
151
+ elif bench_suite is not None:
152
+ wb_dict[f"{bench_suite}-{metric}/{task_name}"] = value
153
+ else:
154
+ wb_dict[f"{metric}/{task_name}"] = value
155
+ # e.g. MMLU
156
+ for name, values in bench_averages.items():
157
+ for metric, values in values.items():
158
+ hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
159
+ wb_dict[f"{metric}/{name}"] = sum(values) / len(values)
160
+
161
+ for task_name, task_details in details.items():
162
+ if len(task_details) <= 1:
163
+ continue
164
+ columns = list(flatten_dict(asdict(task_details[0])).keys())
165
+ table = wandb.Table(columns=columns)
166
+ table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[0])).values()])
167
+ table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[1])).values()])
168
+ wandb.log({f"eval_details_{task_name}": table}, step=global_step, commit=False)
169
+
170
+ wandb.log(dict(wb_dict.items()), step=global_step, commit=True)
171
+
172
+ # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)
173
+
174
+ # We are doing parallel evaluations of multiple checkpoints and recording the steps out of order.
175
+ # This confuses tensorboard, so the easiest fix is to rename the files in checkpoint order.
176
+ # See: https://github.com/tensorflow/tensorboard/issues/5958
177
+ # But tensorboardX doesn't let us control the prefix of the files (only the suffix), so we need to do it ourselves before committing the files.
178
+
179
+ hlog(f"Pushed to wandb" f" at {output_dir_tb} and global_step {global_step}")
180
+
181
+
182
+ def push_results_to_tensorboard( # noqa: C901
183
+ config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
184
+ ):
185
+ # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
186
+ lighteval_config = config.lighteval
187
+ try:
188
+ global_step = config.general.step
189
+ except ValueError:
190
+ global_step = 0
191
+ if config.lighteval.logging.tensorboard_metric_prefix is not None:
192
+ prefix = config.lighteval.logging.tensorboard_metric_prefix
193
+ else:
194
+ prefix = "eval"
195
+ output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
196
+ output_dir_tb.mkdir(parents=True, exist_ok=True)
197
+ tb_context = HFSummaryWriter(
198
+ logdir=str(output_dir_tb),
199
+ repo_id=lighteval_config.logging.hub_repo_tensorboard,
200
+ repo_private=True,
201
+ path_in_repo="tb",
202
+ commit_every=6000, # Very long time so that we can change our files names and trigger push ourselves (see below)
203
+ )
204
+ bench_averages = {}
205
+ for name, values in results.items():
206
+ splited_name = name.split("|")
207
+ if len(splited_name) == 3:
208
+ _, task_name, _ = splited_name
209
+ else:
210
+ task_name = name
211
+ bench_suite = None
212
+ if ":" in task_name:
213
+ bench_suite = task_name.split(":")[0] # e.g. MMLU
214
+ hlog(f"bench_suite {bench_suite} in {task_name}")
215
+ for metric, value in values.items():
216
+ if "stderr" in metric:
217
+ continue
218
+ if bench_suite not in bench_averages:
219
+ bench_averages[bench_suite] = {}
220
+ bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
221
+ hlog(f"Pushing {task_name} {values} to tensorboard")
222
+ for metric, value in values.items():
223
+ if "stderr" in metric:
224
+ tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step)
225
+ elif bench_suite is not None:
226
+ tb_context.add_scalar(f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step)
227
+ else:
228
+ tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step)
229
+ # e.g. MMLU
230
+ for name, values in bench_averages.items():
231
+ for metric, values in values.items():
232
+ hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
233
+ tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step)
234
+
235
+ tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step)
236
+ # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)
237
+
238
+ for task_name, task_details in details.items():
239
+ tb_context.add_text(
240
+ f"eval_details_{task_name}",
241
+ obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}),
242
+ global_step=global_step,
243
+ )
244
+
245
+ # We are doing parallel evaluations of multiple checkpoints and recording the steps out of order.
246
+ # This confuses tensorboard, so the easiest fix is to rename the files in checkpoint order.
247
+ # See: https://github.com/tensorflow/tensorboard/issues/5958
248
+ # But tensorboardX doesn't let us control the prefix of the files (only the suffix), so we need to do it ourselves before committing the files.
249
+
250
+ tb_context.close() # flushes the unfinished write operations
251
+ time.sleep(5)
252
+ files = os.listdir(output_dir_tb)
253
+ for file in files:
254
+ os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}"))
255
+
256
+ # Now we can push to the hub
257
+ tb_context.scheduler.trigger()
258
+ hlog(
259
+ f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/"
260
+ f" at {output_dir_tb} and global_step {global_step}"
261
+ )
262
+
263
+
264
+ @htrack()
265
+ def main(args):
266
+ cache_dir = args.cache_dir or CACHE_DIR
267
+ check_env()
268
+
269
+ dist.initialize_torch_distributed()
270
+
271
+ with htrack_block("get config"):
272
+ if not args.checkpoint_config_path.endswith(".yaml"):
273
+ raise ValueError("The checkpoint path should point to a YAML file")
274
+ local_config_path = args.checkpoint_config_path
275
+ if args.checkpoint_config_path.startswith("s3:/"):
276
+ local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir)
277
+ with local_ranks_zero_first():
278
+ if os.environ.get("LOCAL_RANK", None) == "0":
279
+ os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
280
+ fs_copy(args.checkpoint_config_path, local_config_path)
281
+
282
+ brrr_config: BrrrConfig = get_config_from_file(local_config_path, config_class=BrrrConfig, model_config_class=MistralConfig)
283
+
284
+ if args.lighteval_override:
285
+ local_override_path = args.lighteval_override
286
+ if args.lighteval_override.startswith("s3:/"):
287
+ local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
288
+ with local_ranks_zero_first():
289
+ if os.environ.get("LOCAL_RANK", None) == "0":
290
+ os.makedirs(os.path.dirname(local_override_path), exist_ok=True)
291
+ fs_copy(args.lighteval_override, local_override_path)
292
+ lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig)
293
+ lighteval_config = lighteval_brrr_config.lighteval
294
+ brrr_config.lighteval = lighteval_config
295
+ else:
296
+ local_override_path = ""
297
+ lighteval_config = brrr_config.lighteval
298
+
299
+ parallel_context = ParallelContext(
300
+ tensor_parallel_size=lighteval_config.parallelism.tp,
301
+ pipeline_parallel_size=lighteval_config.parallelism.pp,
302
+ data_parallel_size=lighteval_config.parallelism.dp,
303
+ )
304
+
305
+ evaluation_tracker = EvaluationTracker(token=TOKEN)
306
+ evaluation_tracker.general_config_logger.log_args_info(
307
+ num_fewshot_seeds=1,
308
+ override_batch_size=None,
309
+ max_samples=lighteval_config.tasks.max_samples,
310
+ job_id=os.environ.get("SLURM_JOB_ID", None),
311
+ config=brrr_config.as_dict(),
312
+ )
313
+
314
+ with htrack_block("Test all gather"):
315
+ hlog("Test gather tensor")
316
+ # Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading
317
+ log_rank(
318
+ f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
319
+ logger=logger,
320
+ level=logging.WARNING,
321
+ group=parallel_context.dp_pg,
322
+ rank=0,
323
+ )
324
+ test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
325
+ test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
326
+ dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
327
+ dist.barrier()
328
+ log_rank(
329
+ f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
330
+ logger=logger,
331
+ level=logging.WARNING,
332
+ group=parallel_context.dp_pg,
333
+ rank=0,
334
+ )
335
+
336
+ del test_tensor_list
337
+ del test_tensor
338
+
339
+ with htrack_block("Model loading"):
340
+ # We need to load the model in the main process first to avoid downloading the model multiple times
341
+ model = BRRRModel(
342
+ checkpoint_path=args.checkpoint_config_path.replace("config.yaml", ""),
343
+ model_args=brrr_config.model,
344
+ tokenizer=brrr_config.tokenizer,
345
+ parallel_context=parallel_context,
346
+ parallel_config=lighteval_config.parallelism,
347
+ lighteval_config=lighteval_config,
348
+ batch_size=lighteval_config.batch_size,
349
+ cache_dir=os.environ.get("HF_HOME", "/scratch"),
350
+ debug_one_layer_model=False,
351
+ s5cmd_path=args.s5cmd_path,
352
+ s5cmd_numworkers=args.s5cmd_numworkers,
353
+ s5cmd_concurrency=args.s5cmd_concurrency,
354
+ model_class=MistralForTraining
355
+ )
356
+ model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}")
357
+ evaluation_tracker.general_config_logger.log_model_info(model_info)
358
+
359
+ with htrack_block("Tasks loading"):
360
+ with local_ranks_zero_first():
361
+ tasks_selection = lighteval_config.tasks.tasks
362
+ if lighteval_config.tasks.custom_tasks_file:
363
+ _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
364
+ if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
365
+ tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]
366
+
367
+ task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
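+ # taskinfo_selector is assumed to parse the "suite|task|num_fewshot|truncation" strings into
+ # concrete task names plus per-task few-shot settings.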
368
+ task_dict = Registry(cache_dir=cache_dir).get_task_dict(
369
+ task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
370
+ )
371
+ # Loading all the dataset in a distributed manner
372
+ LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)
373
+
374
+ evaluation_tracker.task_config_logger.log(task_dict)
375
+
376
+ hlog("Loading documents, and requests")
377
+ requests, docs = create_requests_from_tasks(
378
+ task_dict=task_dict,
379
+ fewshot_dict=few_shots_dict,
380
+ num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
381
+ lm=model,
382
+ max_samples=lighteval_config.tasks.max_samples,
383
+ evaluation_tracker=evaluation_tracker,
384
+ use_chat_template=False
385
+ )
386
+
387
+ with htrack_block("Setting seeds and waiting for all processes"):
388
+ hlog(f"setting seed to {1234} for random and numpy")
389
+ random.seed(1234)
390
+ np.random.seed(1234)
391
+ dist.barrier()
392
+
393
+ with htrack_block("Evaluation"):
394
+ hlog(f"Evaluate on {len(task_names_list)} tasks.")
395
+ evaluation_tracker = evaluate(
396
+ lm=model,
397
+ requests_dict=requests,
398
+ docs=docs,
399
+ task_dict=task_dict,
400
+ override_bs=lighteval_config.batch_size,
401
+ evaluation_tracker=evaluation_tracker,
402
+ )
403
+
404
+ if dist.get_rank(parallel_context.world_pg) == 0:
405
+ with htrack_block("Compiling and saving results"):
406
+ evaluation_tracker.general_config_logger.log_end_time()
407
+ evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
408
+ evaluation_tracker.details_logger.aggregate()
409
+
410
+ if lighteval_config.logging.local_output_path:
411
+ evaluation_tracker.save(
412
+ output_dir=lighteval_config.logging.local_output_path,
413
+ push_results_to_hub=lighteval_config.logging.push_results_to_hub,
414
+ push_details_to_hub=lighteval_config.logging.push_details_to_hub,
415
+ public=False,
416
+ push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
417
+ )
418
+
419
+ if lighteval_config.logging.push_results_to_tensorboard:
420
+ push_results_to_tensorboard(
421
+ config=brrr_config,
422
+ results=evaluation_tracker.metrics_logger.metric_aggregated,
423
+ details=evaluation_tracker.details_logger.details,
424
+ )
425
+ if lighteval_config.wandb is not None:
426
+ push_results_to_wandb(
427
+ config=brrr_config,
428
+ results=evaluation_tracker.metrics_logger.metric_aggregated,
429
+ details=evaluation_tracker.details_logger.details,
430
+ )
431
+
432
+ final_dict = evaluation_tracker.generate_final_dict()
433
+
434
+ hlog(make_results_table(final_dict))
435
+
436
+ return final_dict
437
+
438
+
439
+ if __name__ == "__main__":
440
+ parser = get_parser()
441
+ args, unknowns = parser.parse_known_args()
442
+ main(args)
run_train.py CHANGED
@@ -8,11 +8,11 @@ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
  ```
  """
  import argparse
+ from nanotron.trainer import DistributedTrainer
 
- from config_tiny_mistral import MistralConfig
  from dataloader import get_dataloader
  from modeling_mistral import MistralForTraining
- from nanotron.trainer import DistributedTrainer
+ from config_tiny_mistral import MistralConfig
 
 
  def get_args():