add eval code
- custom_evaluation_tasks.py +650 -0
- custom_evaluation_utils.py +158 -0
- lighteval_eval_config.yaml +45 -0
- run_evals.py +442 -0
- run_train.py +2 -2
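
This commit adds a lighteval-based evaluation setup for the nanotron Mistral example: custom_evaluation_utils.py defines the CustomEvaluationTask and Metrics helpers, custom_evaluation_tasks.py builds the TASKS_TABLE and TASKS_GROUPS that lighteval imports, lighteval_eval_config.yaml overrides the lighteval section of a training config, and run_evals.py runs the evaluation on a nanotron checkpoint. A typical launch, as documented in the run_evals.py docstring, looks like:

export CUDA_DEVICE_MAX_CONNECTIONS=1  # important for some distributed operations
torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/Mistral-7B-v0.1/config.yaml \
    --lighteval-override ./lighteval_eval_config.yaml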
custom_evaluation_tasks.py
ADDED
@@ -0,0 +1,650 @@
# ruff: noqa: F405, F403, F401
"""
Custom evaluation tasks for lighteval

This file generally creates just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval.
"""
import re
from dataclasses import asdict
from typing import Dict, List, Tuple

from custom_evaluation_utils import *
from lighteval.tasks.requests import Doc

# fmt: off
LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"]
# fmt: on

_TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = []
_TASKS: List[CustomEvaluationTask] = []

## COMMON_SENSE_REASONING_TASKS ##
COMMON_SENSE_REASONING_TASKS = [
    CustomEvaluationTask(
        name="hellaswag",
        prompt_function="hellaswag_prompt",
        hf_repo="hellaswag",
        hf_subset="default",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="winogrande",
        prompt_function="winogrande",
        hf_repo="winogrande",
        hf_subset="winogrande_xl",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="piqa",
        prompt_function="piqa_harness",
        hf_repo="piqa",
        hf_subset="plain_text",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="siqa",
        prompt_function="siqa_prompt",
        hf_repo="lighteval/siqa",
        hf_subset="default",
        hf_avail_splits=["train", "validation"],
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="openbookqa",
        prompt_function="openbookqa",
        hf_repo="openbookqa",
        hf_subset="main",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="arc:easy",
        prompt_function="arc",
        hf_repo="ai2_arc",
        hf_subset="ARC-Easy",
        evaluation_splits=["test"],
        generation_size=1,
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="arc:challenge",
        prompt_function="arc",
        hf_repo="ai2_arc",
        hf_subset="ARC-Challenge",
        evaluation_splits=["test"],
        generation_size=1,
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
    CustomEvaluationTask(
        name="commonsense_qa",
        prompt_function="commonsense_qa_prompt",
        hf_repo="commonsense_qa",
        hf_subset="default",
        metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
    ),
]


def commonsense_qa_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["question"],
        choices=[f" {c}" for c in line["choices"]["text"]],
        gold_index=LETTER_INDICES.index(line["answerKey"].strip()),
        instruction="",
    )


def siqa_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["context"] + " " + line["question"],
        choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]],
        gold_index=int(line["label"]) - 1,
        instruction="",
    )


def hellaswag_prompt(line, task_name: str = None):
    def preprocess(text):
        """Comes from AiHarness"""
        # text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub("\\[.*?\\]", "", text)
        text = text.replace("  ", " ")
        return text

    ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} "
    return Doc(
        task_name=task_name,
        query=preprocess(line["activity_label"] + ": " + ctx),
        choices=[" " + preprocess(ending) for ending in line["endings"]],
        gold_index=int(line["label"]) if line["label"] != "" else -1,  # -1 for test
        # "metric": "choices_loglikelihood",
    )


# 0 shot for common sense
COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS]
_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING)
_TASKS += COMMON_SENSE_REASONING_TASKS

## WORLD_KNOWLEDGE_TASKS ##

WORLD_KNOWLEDGE_TASKS = [
    CustomEvaluationTask(
        name="trivia_qa",
        prompt_function="triviaqa",
        hf_repo="trivia_qa",
        hf_subset="rc.nocontext",
        metric=[Metrics.quasi_exact_match2],
        generation_size=20,
        stop_sequence=["\n", ".", ","],
    ),
    CustomEvaluationTask(
        name="natural_questions",
        prompt_function="natural_questions_prompt",
        hf_repo="lighteval/natural_questions_clean",
        hf_subset="default",
        metric=[Metrics.quasi_exact_match2],
        generation_size=20,
        stop_sequence=["\n", ".", ","],
    ),
]


def natural_questions_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["question"] + "?\nAnswer: ",
        choices=[line["short_answers"]],
        gold_index=0,
        instruction="",
    )


WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS]
# WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS}
_TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING)
_TASKS += WORLD_KNOWLEDGE_TASKS

## Reading comprehension ##

READING_COMP_TASKS = [
    CustomEvaluationTask(
        name="super_glue:boolq",
        prompt_function="boolq_prompt",
        hf_repo="super_glue",
        hf_subset="boolq",
        metric=[Metrics.target_perplexity],
    ),
    CustomEvaluationTask(
        name="quac",
        prompt_function="quac",
        hf_repo="lighteval/quac_helm",
        hf_subset="default",
        metric=[Metrics.quasi_exact_match2],
        generation_size=20,
        stop_sequence=["\n", ".", ","],
    ),
]


def boolq_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:",
        choices=[" No", " Yes"],  # Only gold
        gold_index=int(line["label"]),
    )


READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS]
_TASKS_STRINGS.extend(READING_COMP_STRING)
_TASKS += READING_COMP_TASKS


## MATH ##
class CustomMathEvaluationTask(CustomEvaluationTask):
    """Custom class for math tasks with all the defaults set"""

    def __init__(
        self,
        name,
        prompt_function="math",
        hf_repo="lighteval/MATH",
        hf_subset=None,
        metric=[Metrics.math_quasi_exact_match],
        hf_avail_splits=None,
        evaluation_splits=["test"],
        few_shots_split=None,
        few_shots_select=None,
        suite=["custom"],
        generation_size=40,
        stop_sequence=None,
        output_regex=None,
        frozen=False,
    ):
        super().__init__(
            name=name,
            prompt_function=prompt_function,
            hf_repo=hf_repo,
            hf_subset=hf_subset,
            metric=metric,
            hf_avail_splits=hf_avail_splits,
            evaluation_splits=evaluation_splits,
            few_shots_split=few_shots_split,
            few_shots_select=few_shots_select,
            suite=suite,
            generation_size=generation_size,
            stop_sequence=stop_sequence,
            output_regex=output_regex,
            frozen=frozen,
        )


MATH_TASKS = [
    CustomMathEvaluationTask(name="math:algebra", hf_subset="algebra"),
    CustomMathEvaluationTask(name="math:counting_and_probability", hf_subset="counting_and_probability"),
    CustomMathEvaluationTask(name="math:geometry", hf_subset="geometry"),
    CustomMathEvaluationTask(name="math:intermediate_algebra", hf_subset="intermediate_algebra"),
    CustomMathEvaluationTask(name="math:number_theory", hf_subset="number_theory"),
    CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"),
    CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"),
]
GSM8K = CustomEvaluationTask(
    name="gsm8k",
    prompt_function="gsm8k",
    hf_repo="gsm8k",
    hf_subset="main",
    hf_avail_splits=["train", "test"],
    evaluation_splits=["test"],
    metric=[Metrics.perfect_exact_match],
    generation_size=10,
    stop_sequence=["\n"],
)


MATH_STRING = [(t, f"custom|{t.name}|4|1") for t in MATH_TASKS]
GSM8K_STRING = [(GSM8K, f"custom|{GSM8K.name}|8|1")]
_TASKS_STRINGS.extend(MATH_STRING)
_TASKS_STRINGS.extend(GSM8K_STRING)
_TASKS += MATH_TASKS + [GSM8K]


## MMLU ##
class CustomMMLUEvaluationTask(CustomEvaluationTask):
    def __init__(
        self,
        name,
        prompt_function="mmlu_prompt",
        hf_repo="lighteval/mmlu",
        hf_subset=None,
        # metric=[Metrics.loglikelihood_acc_single_token],
        metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
        hf_avail_splits=None,
        evaluation_splits=["test"],
        few_shots_split="dev",
        few_shots_select=None,
        suite=None,
        generation_size=-1,
        stop_sequence=None,
        output_regex=None,
        frozen=False,
    ):
        super().__init__(
            name=name,
            prompt_function=prompt_function,
            hf_repo=hf_repo,
            hf_subset=hf_subset,
            metric=metric,
            hf_avail_splits=hf_avail_splits,
            evaluation_splits=evaluation_splits,
            few_shots_split=few_shots_split,
            few_shots_select=few_shots_select,
            suite=suite,
            generation_size=generation_size,
            stop_sequence=stop_sequence,
            output_regex=output_regex,
            frozen=frozen,
        )


MMLU_TASKS = [
    CustomMMLUEvaluationTask(name="mmlu:abstract_algebra", hf_subset="abstract_algebra"),
    CustomMMLUEvaluationTask(name="mmlu:anatomy", hf_subset="anatomy"),
    CustomMMLUEvaluationTask(name="mmlu:astronomy", hf_subset="astronomy"),
    CustomMMLUEvaluationTask(name="mmlu:business_ethics", hf_subset="business_ethics"),
    CustomMMLUEvaluationTask(name="mmlu:clinical_knowledge", hf_subset="clinical_knowledge"),
    CustomMMLUEvaluationTask(name="mmlu:college_biology", hf_subset="college_biology"),
    CustomMMLUEvaluationTask(name="mmlu:college_chemistry", hf_subset="college_chemistry"),
    CustomMMLUEvaluationTask(name="mmlu:college_computer_science", hf_subset="college_computer_science"),
    CustomMMLUEvaluationTask(name="mmlu:college_mathematics", hf_subset="college_mathematics"),
    CustomMMLUEvaluationTask(name="mmlu:college_medicine", hf_subset="college_medicine"),
    CustomMMLUEvaluationTask(name="mmlu:college_physics", hf_subset="college_physics"),
    CustomMMLUEvaluationTask(name="mmlu:computer_security", hf_subset="computer_security"),
    CustomMMLUEvaluationTask(name="mmlu:conceptual_physics", hf_subset="conceptual_physics"),
    CustomMMLUEvaluationTask(name="mmlu:econometrics", hf_subset="econometrics"),
    CustomMMLUEvaluationTask(name="mmlu:electrical_engineering", hf_subset="electrical_engineering"),
    CustomMMLUEvaluationTask(name="mmlu:elementary_mathematics", hf_subset="elementary_mathematics"),
    CustomMMLUEvaluationTask(name="mmlu:formal_logic", hf_subset="formal_logic"),
    CustomMMLUEvaluationTask(name="mmlu:global_facts", hf_subset="global_facts"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_biology", hf_subset="high_school_biology"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_chemistry", hf_subset="high_school_chemistry"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_computer_science", hf_subset="high_school_computer_science"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_european_history", hf_subset="high_school_european_history"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_geography", hf_subset="high_school_geography"),
    CustomMMLUEvaluationTask(
        name="mmlu:high_school_government_and_politics", hf_subset="high_school_government_and_politics"
    ),
    CustomMMLUEvaluationTask(name="mmlu:high_school_macroeconomics", hf_subset="high_school_macroeconomics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_mathematics", hf_subset="high_school_mathematics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_microeconomics", hf_subset="high_school_microeconomics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_physics", hf_subset="high_school_physics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_psychology", hf_subset="high_school_psychology"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_statistics", hf_subset="high_school_statistics"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_us_history", hf_subset="high_school_us_history"),
    CustomMMLUEvaluationTask(name="mmlu:high_school_world_history", hf_subset="high_school_world_history"),
    CustomMMLUEvaluationTask(name="mmlu:human_aging", hf_subset="human_aging"),
    CustomMMLUEvaluationTask(name="mmlu:human_sexuality", hf_subset="human_sexuality"),
    CustomMMLUEvaluationTask(name="mmlu:international_law", hf_subset="international_law"),
    CustomMMLUEvaluationTask(name="mmlu:jurisprudence", hf_subset="jurisprudence"),
    CustomMMLUEvaluationTask(name="mmlu:logical_fallacies", hf_subset="logical_fallacies"),
    CustomMMLUEvaluationTask(name="mmlu:machine_learning", hf_subset="machine_learning"),
    CustomMMLUEvaluationTask(name="mmlu:management", hf_subset="management"),
    CustomMMLUEvaluationTask(name="mmlu:marketing", hf_subset="marketing"),
    CustomMMLUEvaluationTask(name="mmlu:medical_genetics", hf_subset="medical_genetics"),
    CustomMMLUEvaluationTask(name="mmlu:miscellaneous", hf_subset="miscellaneous"),
    CustomMMLUEvaluationTask(name="mmlu:moral_disputes", hf_subset="moral_disputes"),
    CustomMMLUEvaluationTask(name="mmlu:moral_scenarios", hf_subset="moral_scenarios"),
    CustomMMLUEvaluationTask(name="mmlu:nutrition", hf_subset="nutrition"),
    CustomMMLUEvaluationTask(name="mmlu:philosophy", hf_subset="philosophy"),
    CustomMMLUEvaluationTask(name="mmlu:prehistory", hf_subset="prehistory"),
    CustomMMLUEvaluationTask(name="mmlu:professional_accounting", hf_subset="professional_accounting"),
    CustomMMLUEvaluationTask(name="mmlu:professional_law", hf_subset="professional_law"),
    CustomMMLUEvaluationTask(name="mmlu:professional_medicine", hf_subset="professional_medicine"),
    CustomMMLUEvaluationTask(name="mmlu:professional_psychology", hf_subset="professional_psychology"),
    CustomMMLUEvaluationTask(name="mmlu:public_relations", hf_subset="public_relations"),
    CustomMMLUEvaluationTask(name="mmlu:security_studies", hf_subset="security_studies"),
    CustomMMLUEvaluationTask(name="mmlu:sociology", hf_subset="sociology"),
    CustomMMLUEvaluationTask(name="mmlu:us_foreign_policy", hf_subset="us_foreign_policy"),
    CustomMMLUEvaluationTask(name="mmlu:virology", hf_subset="virology"),
    CustomMMLUEvaluationTask(name="mmlu:world_religions", hf_subset="world_religions"),
]


def mmlu_harness(line, task_name: str = None):
    topic = line["subject"]
    prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n"
    prompt += line["question"] + "\n"
    prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])])
    prompt += "Answer:"

    gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"]
    "__few_shots" in line and line["__few_shots"] is True  # We are adding few shots

    return Doc(
        task_name=task_name,
        query=prompt,
        choices=[" A", " B", " C", " D"],
        target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix],
        gold_index=gold_ix,
        instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n",
    )


def mmlu_prompt(line, task_name: str = None):
    """MMLU prompt without letters"""
    topic = line["subject"]
    prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: "
    prompt += line["question"] + "\nAnswer:"

    return Doc(
        task_name=task_name,
        query=prompt,
        choices=[f" {c}" for c in line["choices"]],
        gold_index=line["answer"],
        instruction=f"The following are questions about {topic.replace('_', ' ')}.\n",
    )


# MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS}
MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS]
_TASKS_STRINGS.extend(MMLU_STRING)
_TASKS += MMLU_TASKS

## BBH ##


class CustomBBHEvaluationTask(CustomEvaluationTask):
    def __init__(
        self,
        name,
        prompt_function="bbh_prompt",
        hf_repo="lighteval/big_bench_hard",
        hf_subset=None,
        metric=[Metrics.exact_match],
        hf_avail_splits=["train"],
        evaluation_splits=["train"],
        few_shots_split="train",
        few_shots_select=None,
        suite=None,
        generation_size=4,
        stop_sequence=None,
        output_regex=None,
        frozen=False,
    ):
        super().__init__(
            name=name,
            prompt_function=prompt_function,
            hf_repo=hf_repo,
            hf_subset=hf_subset,
            metric=metric,
            hf_avail_splits=hf_avail_splits,
            evaluation_splits=evaluation_splits,
            few_shots_split=few_shots_split,
            few_shots_select=few_shots_select,
            suite=suite,
            generation_size=generation_size,
            stop_sequence=stop_sequence,
            output_regex=output_regex,
            frozen=frozen,
        )


BBH_TASKS = [
    CustomBBHEvaluationTask(name="bbh:boolean_expressions", hf_subset="boolean_expressions"),
    CustomBBHEvaluationTask(name="bbh:causal_judgement", hf_subset="causal_judgement"),
    CustomBBHEvaluationTask(name="bbh:date_understanding", hf_subset="date_understanding"),
    CustomBBHEvaluationTask(name="bbh:disambiguation_qa", hf_subset="disambiguation_qa"),
    CustomBBHEvaluationTask(name="bbh:dyck_languages", hf_subset="dyck_languages"),
    CustomBBHEvaluationTask(name="bbh:formal_fallacies", hf_subset="formal_fallacies"),
    CustomBBHEvaluationTask(name="bbh:geometric_shapes", hf_subset="geometric_shapes"),
    CustomBBHEvaluationTask(name="bbh:hyperbaton", hf_subset="hyperbaton"),
    CustomBBHEvaluationTask(name="bbh:logical_deduction_five_objects", hf_subset="logical_deduction_five_objects"),
    CustomBBHEvaluationTask(name="bbh:logical_deduction_seven_objects", hf_subset="logical_deduction_seven_objects"),
    CustomBBHEvaluationTask(name="bbh:logical_deduction_three_objects", hf_subset="logical_deduction_three_objects"),
    CustomBBHEvaluationTask(name="bbh:movie_recommendation", hf_subset="movie_recommendation"),
    CustomBBHEvaluationTask(name="bbh:multistep_arithmetic_two", hf_subset="multistep_arithmetic_two"),
    CustomBBHEvaluationTask(name="bbh:navigate", hf_subset="navigate"),
    CustomBBHEvaluationTask(name="bbh:object_counting", hf_subset="object_counting"),
    CustomBBHEvaluationTask(name="bbh:penguins_in_a_table", hf_subset="penguins_in_a_table"),
    CustomBBHEvaluationTask(name="bbh:reasoning_about_colored_objects", hf_subset="reasoning_about_colored_objects"),
    CustomBBHEvaluationTask(name="bbh:ruin_names", hf_subset="ruin_names"),
    CustomBBHEvaluationTask(
        name="bbh:salient_translation_error_detection", hf_subset="salient_translation_error_detection"
    ),
    CustomBBHEvaluationTask(name="bbh:snarks", hf_subset="snarks"),
    CustomBBHEvaluationTask(name="bbh:sports_understanding", hf_subset="sports_understanding"),
    CustomBBHEvaluationTask(name="bbh:temporal_sequences", hf_subset="temporal_sequences"),
    CustomBBHEvaluationTask(
        name="bbh:tracking_shuffled_objects_five_objects", hf_subset="tracking_shuffled_objects_five_objects"
    ),
    CustomBBHEvaluationTask(
        name="bbh:tracking_shuffled_objects_seven_objects", hf_subset="tracking_shuffled_objects_seven_objects"
    ),
    CustomBBHEvaluationTask(
        name="bbh:tracking_shuffled_objects_three_objects", hf_subset="tracking_shuffled_objects_three_objects"
    ),
    CustomBBHEvaluationTask(name="bbh:web_of_lies", hf_subset="web_of_lies"),
    CustomBBHEvaluationTask(name="bbh:word_sorting", hf_subset="word_sorting"),
]


def bbh_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["input"] + "\nAnswer: ",
        choices=[line["target"]],
        gold_index=0,
    )


# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS}
BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS]
_TASKS_STRINGS.extend(BBH_STRING)
_TASKS += BBH_TASKS


## AGI eval ##
class CustomAGIEvalEvaluationTask(CustomEvaluationTask):
    def __init__(
        self,
        name,
        prompt_function="agi_eval_prompt_no_letters",
        hf_repo="lighteval/agi_eval_en",
        hf_subset=None,
        # metric=[Metrics.loglikelihood_acc_single_token],
        metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace],
        hf_avail_splits=["train", "validation"],
        evaluation_splits=["train"],
        few_shots_split="validation",
        few_shots_select=None,
        suite=None,
        generation_size=-1,
        stop_sequence=None,
        output_regex=None,
        frozen=False,
    ):
        super().__init__(
            name=name,
            prompt_function=prompt_function,
            hf_repo=hf_repo,
            hf_subset=hf_subset,
            metric=metric,
            hf_avail_splits=hf_avail_splits,
            evaluation_splits=evaluation_splits,
            few_shots_split=few_shots_split,
            few_shots_select=few_shots_select,
            suite=suite,
            generation_size=generation_size,
            stop_sequence=stop_sequence,
            output_regex=output_regex,
            frozen=frozen,
        )


AGIEVAL_TASKS = [
    CustomAGIEvalEvaluationTask(name="agi_eval:aqua_rat", hf_subset="aqua_rat"),
    CustomAGIEvalEvaluationTask(name="agi_eval:logiqa-en", hf_subset="logiqa-en"),
    CustomAGIEvalEvaluationTask(name="agi_eval:lsat-ar", hf_subset="lsat-ar"),
    CustomAGIEvalEvaluationTask(name="agi_eval:lsat-lr", hf_subset="lsat-lr"),
    CustomAGIEvalEvaluationTask(name="agi_eval:lsat-rc", hf_subset="lsat-rc"),
    CustomAGIEvalEvaluationTask(
        name="agi_eval:math",
        hf_subset="math",
        prompt_function="agi_eval_math_prompt",
        metric=[Metrics.exact_match, Metrics.quasi_exact_match2],
        generation_size=40,
    ),
    CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"),
    CustomAGIEvalEvaluationTask(name="agi_eval:sat-math", hf_subset="sat-math"),
]


def agi_eval_math_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["question"],
        choices=[line["answer"]],
        gold_index=0,
        instruction="",
    )


def agi_eval_prompt(line, task_name: str = None):
    cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]]
    prompt = "The following are multiple choice questions (with answers).\n\n"
    prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n"
    prompt += "Answer: "

    choices = LETTER_INDICES[: len(line["options"])]

    output = Doc(
        query=prompt,
        instruction="The following are multiple choice questions (with answers).\n\n",
    )

    if line["label"]:
        output.choices = choices
        output.gold_index = LETTER_INDICES.index(line["label"].strip())
    else:
        output.choices = [line["answer"]]
        output.gold_index = 0

    return output


def agi_eval_prompt_no_letters(line, task_name: str = None):
    cleaned_options = [
        " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "")
        for o in line["options"]
    ]

    output = Doc(
        query=line["question"],
        choices=cleaned_options,
        gold_index=LETTER_INDICES.index(line["label"].strip()),
        instruction="",
    )

    return output


# AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS}
AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS]
_TASKS_STRINGS.extend(AGIEVAL_STRING)
_TASKS += AGIEVAL_TASKS


## HUMAN EVAL ##
# human_eval = CustomEvaluationTask(
#         name="human_eval",
#         prompt_function="human_eval",
#         hf_repo="lighteval/human_eval",
#         metric=["human_eval_pass_at_1"],
#     ),


def has_generative_metrics(task: CustomEvaluationTask) -> bool:
    for metric in task.metric:
        if metric in NEEDS_GENERATION_ONLY:
            return True
    return False


EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING])

# Convert to dict for lighteval
TASKS_TABLE = [asdict(task) for task in _TASKS]
# You can have a few pre-organised groups of tasks
TASKS_GROUPS = {
    "all": ",".join(t[1] for t in _TASKS_STRINGS),
    "early-signal": EARLY_SIGNAL_TASKS,
    "non-generatives": ",".join(t for k, t in _TASKS_STRINGS if not has_generative_metrics(k)),
    "generatives": ",".join(t for k, t in _TASKS_STRINGS if has_generative_metrics(k)),
}

if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
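
As a quick reference (not part of this commit), the exported objects can be inspected with a minimal sketch like the following, assuming both new files are importable from the working directory:

# hypothetical quick check, mirroring the __main__ block above
import custom_evaluation_tasks as cet

print(len(cet.TASKS_TABLE))                    # number of task configs handed to lighteval
print(cet.TASKS_GROUPS["early-signal"][:120])  # comma-separated "custom|<task>|<fewshot>|1" strings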
custom_evaluation_utils.py
ADDED
@@ -0,0 +1,158 @@
"""
Custom evaluation tasks for lighteval
"""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Tuple, Union


class Metrics(Enum):
    any_target_loglikelihood_acc = auto()
    bert_score = auto()
    bias = auto()
    bits_per_byte = auto()
    bleu = auto()
    bleu_1 = auto()
    bleu_4 = auto()
    byte_perplexity = auto()
    chrf = auto()
    code_eval_APPS = auto()
    code_eval_HE = auto()
    copyright = auto()
    disinformation = auto()
    exact_match = auto()
    exact_set_match = auto()
    extractiveness = auto()
    f1_from_bags = auto()
    f1_quasi = auto()
    f1_sequence = auto()
    f1_set_match = auto()
    faithfulness = auto()
    iou_set_match = auto()
    log_prob = auto()
    loglikelihood_acc = auto()
    loglikelihood_acc_norm = auto()
    loglikelihood_acc_norm_nospace = auto()
    loglikelihood_acc_norm_single_token = auto()
    loglikelihood_acc_single_token = auto()
    loglikelihood_f1 = auto()
    loglikelihood_f1_single_token = auto()
    math_quasi_exact_match = auto()
    mc_taco = auto()
    mcc = auto()
    mcc_single_token = auto()
    mrr = auto()
    mrr_single_token = auto()
    multi_fi_numeric = auto()
    one_choice_loglikelihood_acc = auto()
    perfect_exact_match = auto()
    prediction_perplexity = auto()
    prefix_exact_match = auto()
    prefix_quasi_exact_match = auto()
    quasi_exact_match = auto()
    ranking = auto()
    recall_at_1_single_token = auto()
    recall_at_2_single_token = auto()
    recall_at_1 = auto()
    recall_at_2 = auto()
    rouge = auto()
    rouge_1 = auto()
    rouge_2 = auto()
    rouge_l = auto()
    target_perplexity = auto()
    ter = auto()
    toxicity = auto()
    truthfulqa_mc_metrics = auto()
    word_perplexity = auto()

    def __str__(self):
        return self.name.replace("_at_", "@")


NEEDS_GENERATION_ONLY = [
    "perfect_exact_match",
    "exact_match",
    "quasi_exact_match",
    "quasi_exact_match2",
    "prefix_exact_match",
    "prefix_quasi_exact_match",
    "math_quasi_exact_match",
    "iou_set_match",
    "exact_set_match",
    "f1_sequence",
    "f1_quasi",
    "f1_set_match",
    "f1_from_bags",
    "chrf",
    "ter",
    "rouge",
    "rouge_1",
    "rouge_2",
    "rouge_l",
    "faithfulness",
    "extractiveness",
    "bert_score",
    "bleu",
    "bleu_1",
    "bleu_4",
    "bias",
    "toxicity",
    "code_eval_HE",
    "code_eval_APPS",
    "copyright",
]


@dataclass(unsafe_hash=True)
class CustomEvaluationTask:
    name: str
    prompt_function: str
    hf_repo: str
    hf_subset: str
    metric: Tuple[Union[str, Metrics]]
    hf_avail_splits: Optional[Tuple[str]] = None
    evaluation_splits: Optional[Tuple[str]] = None
    few_shots_split: Optional[str] = None
    few_shots_select: Optional[str] = None
    generation_size: int = -1
    stop_sequence: Optional[Tuple[str]] = None
    output_regex: Optional[str] = None

    frozen: bool = False
    suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task

    def __post_init__(self):
        self.metric = [str(m) for m in self.metric]
        if self.suite is None:
            self.suite = ["custom"]
        if self.hf_avail_splits is None:
            self.hf_avail_splits = ["train", "validation", "test"]
        if self.evaluation_splits is None:
            self.evaluation_splits = ["validation"]
        if self.stop_sequence is None:
            self.stop_sequence = ["\n"]

        # Convert list to tuple for hashing
        self.metric = tuple(self.metric)
        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits else None
        self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits else None
        self.suite = tuple(self.suite) if self.suite else None
        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence else None


@dataclass(unsafe_hash=True)
class BigCodeEvaluationTask:
    name: str
    bigcode_task: str
    bigcode_task_kwargs: Optional[dict] = None
    n_samples: int = 1
    prefix: Optional[str] = None

    suite: Tuple[str] = None

    def __post_init__(self):
        if self.suite is None:
            self.suite = ("bigcode",)

        # Convert list to tuple for hashing
        self.suite = tuple(self.suite)
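
As a small illustration (not part of this commit) of how CustomEvaluationTask normalizes its fields in __post_init__ (metrics become strings, lists become tuples, defaults are filled in), using a made-up dataset id:

from dataclasses import asdict
from custom_evaluation_utils import CustomEvaluationTask, Metrics

task = CustomEvaluationTask(
    name="demo",
    prompt_function="demo_prompt",
    hf_repo="org/demo-dataset",  # hypothetical dataset id, for illustration only
    hf_subset="default",
    metric=[Metrics.loglikelihood_acc],
)
print(asdict(task)["metric"])  # ('loglikelihood_acc',)
print(asdict(task)["suite"])   # ('custom',)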
lighteval_eval_config.yaml
ADDED
@@ -0,0 +1,45 @@
checkpoints: null
data: null
experiment_logger: null
general: null
kill_switch_path: null
lighteval:
  batch_size: 24
  checkpoints_path: null
  generation: null
  logging:
    hub_repo_details: null
    hub_repo_results: null
    hub_repo_tensorboard: HuggingFaceBR4/thomwolf-nanotron-mistral-7b
    local_output_path: /scratch/thomwolf/lighteval/nanotron-mistral-7b
    push_details_to_hub: false
    push_results_to_hub: false
    push_results_to_tensorboard: true
    tensorboard_metric_prefix: e
  parallelism:
    dp: 4
    pp: 1
    pp_engine: 1f1b
    recompute_granularity: null
    tp: 2
    tp_linear_async_communication: false
    tp_mode: ALL_REDUCE
  slurm: null
  slurm_script_dir: null
  slurm_template: null
  tasks:
    custom_tasks_file: ./custom_evaluation_tasks.py
    dataset_loading_processes: 8
    max_samples: 1000
    multichoice_continuations_start_space: null
    no_multichoice_continuations_start_space: null
    num_fewshot_seeds: null
    tasks: early-signal
logging: null
model: null
optimizer: null
parallelism: null
profiler: null
s3_upload: null
tokenizer: null
tokens: null
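
Note that "tasks: early-signal" refers to the TASKS_GROUPS key defined in custom_evaluation_tasks.py (the common-sense reasoning suite plus MMLU, all zero-shot), and "max_samples: 1000" caps the number of evaluated examples per task.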
run_evals.py
ADDED
@@ -0,0 +1,442 @@
"""
Nanotron Inference Script

Usage:
```
export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
torchrun --nproc_per_node=8 run_evals.py --checkpoint-config-path ./pretrained/Mistral-7B-v0.1/config.yaml \
    --lighteval-override ./lighteval_eval_config.yaml
```
"""
# flake8: noqa: C901
import argparse
import os
import random
import time
from dataclasses import asdict
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import HFSummaryWriter
from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
from lighteval.logging.info_loggers import (
    DetailsLogger,
)
from lighteval.models.model_loader import ModelInfo
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
from lighteval.tasks.registry import Registry, get_custom_tasks, taskinfo_selector
from nanotron import distributed as dist
from nanotron import logging
from nanotron.config import get_config_from_file
from nanotron.logging import get_logger, log_rank
from nanotron.parallel.context import ParallelContext
from nanotron.utils import local_ranks_zero_first

from brrr.config import BrrrConfig
from brrr.experiment_loggers import flatten_dict, obj_to_markdown
from brrr.s3_checkpoints import fs_copy
from brrr.utils import check_env

from lighteval.models.brrr_models import BRRRModel

from modeling_mistral import MistralForTraining
from config_mistral import MistralConfig

logger = get_logger(__name__)

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME", "/scratch")


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint-config-path",
        type=str,
        required=True,
        help="Path to the brrr checkpoint YAML or python config file, potentially on S3",
    )
    parser.add_argument(
        "--lighteval-override",
        type=str,
        help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Local or hub path of an optional tokenizer (if not indicated in the checkpoint)",
    )
    parser.add_argument(
        "--s5cmd-path",
        type=str,
        default="/admin/home/thomwolf/miniconda3/envs/b4r/bin/s5cmd",
        help="Path to s5cmd install",
    )
    parser.add_argument(
        "--s5cmd-numworkers",
        type=int,
        default=64,
        help="s5cmd num workers (optional)",
    )
    parser.add_argument(
        "--s5cmd-concurrency",
        type=int,
        default=10,
        help="s5cmd concurrency (optional)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="",
        help="Cache directory",
    )

    return parser


def push_results_to_wandb(  # noqa: C901
    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
    # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
    lighteval_config = config.lighteval
    try:
        global_step = config.general.step
    except ValueError:
        global_step = 0
    if config.lighteval.logging.tensorboard_metric_prefix is not None:
        prefix = config.lighteval.logging.tensorboard_metric_prefix
    else:
        prefix = "eval"
    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
    output_dir_tb.mkdir(parents=True, exist_ok=True)

    os.environ["WANDB_DISABLE_SERVICE"] = "True"
    import wandb

    wandb.tensorboard.patch(root_logdir=config.lighteval.logging.local_output_path)
    hlog("Starting wandb with WANDB_DISABLE_SERVICE=True")
    wandb.init(
        project=config.lighteval.wandb.wandb_project,
        entity=config.lighteval.wandb.wandb_entity,
        name=config.lighteval.wandb.wandb_run_name,
        config=config.as_dict(),
        # sync_tensorboard=True,
        resume=True,
    )
    wb_dict = {}
    bench_averages = {}
    for name, values in results.items():
        splited_name = name.split("|")
        if len(splited_name) == 3:
            _, task_name, _ = splited_name
        else:
            task_name = name
        bench_suite = None
        if ":" in task_name:
            bench_suite = task_name.split(":")[0]  # e.g. MMLU
            hlog(f"bench_suite {bench_suite} in {task_name}")
            for metric, value in values.items():
                if "stderr" in metric:
                    continue
                if bench_suite not in bench_averages:
                    bench_averages[bench_suite] = {}
                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
        hlog(f"Pushing {task_name} {values} to tensorboard")
        for metric, value in values.items():
            if "stderr" in metric:
                wb_dict[f"stderr_{metric}/{task_name}"] = value
            elif bench_suite is not None:
                wb_dict[f"{bench_suite}-{metric}/{task_name}"] = value
            else:
                wb_dict[f"{metric}/{task_name}"] = value
    # e.g. MMLU
    for name, values in bench_averages.items():
        for metric, values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
            wb_dict[f"{metric}/{name}"] = sum(values) / len(values)

    for task_name, task_details in details.items():
        if len(task_details) <= 1:
            continue
        columns = list(flatten_dict(asdict(task_details[0])).keys())
        table = wandb.Table(columns=columns)
        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[0])).values()])
        table.add_data(*[str(v) for v in flatten_dict(asdict(task_details[1])).values()])
        wandb.log({f"eval_details_{task_name}": table}, step=global_step, commit=False)

    wandb.log(dict(wb_dict.items()), step=global_step, commit=True)

    # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)

    # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
    # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints
    # See: https://github.com/tensorflow/tensorboard/issues/5958
    # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files

    hlog(f"Pushed to wandb" f" at {output_dir_tb} and global_step {global_step}")


def push_results_to_tensorboard(  # noqa: C901
    config: BrrrConfig, results: dict[str, dict[str, float]], details: dict[str, DetailsLogger.CompiledDetail]
):
    # config: BrrrConfig = get_config_from_dict(config, config_class=BrrrConfig)
    lighteval_config = config.lighteval
    try:
        global_step = config.general.step
    except ValueError:
        global_step = 0
    if config.lighteval.logging.tensorboard_metric_prefix is not None:
        prefix = config.lighteval.logging.tensorboard_metric_prefix
    else:
        prefix = "eval"
    output_dir_tb = Path(lighteval_config.logging.local_output_path) / "tb" / (config.general.run + "_" + prefix)
    output_dir_tb.mkdir(parents=True, exist_ok=True)
    tb_context = HFSummaryWriter(
        logdir=str(output_dir_tb),
        repo_id=lighteval_config.logging.hub_repo_tensorboard,
        repo_private=True,
        path_in_repo="tb",
        commit_every=6000,  # Very long time so that we can change our files names and trigger push ourselves (see below)
    )
    bench_averages = {}
    for name, values in results.items():
        splited_name = name.split("|")
        if len(splited_name) == 3:
            _, task_name, _ = splited_name
        else:
            task_name = name
        bench_suite = None
        if ":" in task_name:
            bench_suite = task_name.split(":")[0]  # e.g. MMLU
            hlog(f"bench_suite {bench_suite} in {task_name}")
            for metric, value in values.items():
                if "stderr" in metric:
                    continue
                if bench_suite not in bench_averages:
                    bench_averages[bench_suite] = {}
                bench_averages[bench_suite][metric] = bench_averages[bench_suite].get(metric, []) + [float(value)]
        hlog(f"Pushing {task_name} {values} to tensorboard")
        for metric, value in values.items():
            if "stderr" in metric:
                tb_context.add_scalar(f"stderr_{prefix}/{task_name}/{metric}", value, global_step=global_step)
            elif bench_suite is not None:
                tb_context.add_scalar(f"{prefix}_{bench_suite}/{task_name}/{metric}", value, global_step=global_step)
            else:
                tb_context.add_scalar(f"{prefix}/{task_name}/{metric}", value, global_step=global_step)
    # e.g. MMLU
    for name, values in bench_averages.items():
        for metric, values in values.items():
            hlog(f"Pushing average {name} {metric} {sum(values) / len(values)} to tensorboard")
            tb_context.add_scalar(f"{prefix}/{name}/{metric}", sum(values) / len(values), global_step=global_step)

    tb_context.add_text("eval_config", obj_to_markdown(results), global_step=global_step)
    # tb_context.add_text("eval_sizes", obj_to_markdown(sizes), global_step=global_step)

    for task_name, task_details in details.items():
        tb_context.add_text(
            f"eval_details_{task_name}",
            obj_to_markdown({"0": task_details[0], "1": task_details[1] if len(task_details) > 1 else {}}),
            global_step=global_step,
        )

    # We are doing parallel evaluations of multiple checkpoints and recording the steps not in order
    # This messes up with tensorboard, so the easiest is to rename files in the order of the checkpoints
    # See: https://github.com/tensorflow/tensorboard/issues/5958
    # But tensorboardX don't let us control the prefix of the files (only the suffix), so we need to do it ourselves before commiting the files

    tb_context.close()  # flushes the unfinished write operations
    time.sleep(5)
    files = os.listdir(output_dir_tb)
    for file in files:
        os.rename(os.path.join(output_dir_tb, file), os.path.join(output_dir_tb, f"{global_step:07d}_{file}"))

    # Now we can push to the hub
    tb_context.scheduler.trigger()
    hlog(
        f"Pushed to tensorboard at https://huggingface.co/tensorboard/{lighteval_config.logging.hub_repo_tensorboard}/"
        f" at {output_dir_tb} and global_step {global_step}"
    )


@htrack()
def main(args):
    cache_dir = args.cache_dir or CACHE_DIR
    check_env()

    dist.initialize_torch_distributed()

    with htrack_block("get config"):
        if not args.checkpoint_config_path.endswith(".yaml"):
            raise ValueError("The checkpoint path should point to a YAML file")
        local_config_path = args.checkpoint_config_path
        if args.checkpoint_config_path.startswith("s3:/"):
            local_config_path = args.checkpoint_config_path.replace("s3:/", cache_dir)
            with local_ranks_zero_first():
                if os.environ.get("LOCAL_RANK", None) == "0":
                    os.makedirs(os.path.dirname(local_config_path), exist_ok=True)
                    fs_copy(args.checkpoint_config_path, local_config_path)

        brrr_config: BrrrConfig = get_config_from_file(local_config_path, config_class=BrrrConfig, model_config_class=MistralConfig)

        if args.lighteval_override:
            local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
            if args.lighteval_override.startswith("s3:/"):
                local_override_path = args.lighteval_override.replace("s3:/", cache_dir)
                with local_ranks_zero_first():
                    if os.environ.get("LOCAL_RANK", None) == "0":
                        os.makedirs(os.path.dirname(local_override_path), exist_ok=True)
                        fs_copy(args.lighteval_override, local_override_path)
            lighteval_brrr_config: BrrrConfig = get_config_from_file(local_override_path, config_class=BrrrConfig)
            lighteval_config = lighteval_brrr_config.lighteval
            brrr_config.lighteval = lighteval_config
        else:
            local_override_path = ""
            lighteval_config = brrr_config.lighteval

    parallel_context = ParallelContext(
        tensor_parallel_size=lighteval_config.parallelism.tp,
        pipeline_parallel_size=lighteval_config.parallelism.pp,
        data_parallel_size=lighteval_config.parallelism.dp,
    )

    evaluation_tracker = EvaluationTracker(token=TOKEN)
    evaluation_tracker.general_config_logger.log_args_info(
        num_fewshot_seeds=1,
        override_batch_size=None,
        max_samples=lighteval_config.tasks.max_samples,
        job_id=os.environ.get("SLURM_JOB_ID", None),
        config=brrr_config.as_dict(),
    )

    with htrack_block("Test all gather"):
        hlog("Test gather tensor")
        # Do a first NCCL sync to warmup and try to avoid Timeout after model/data loading
        log_rank(
            f"[TEST] Running NCCL sync for ranks {list(range(parallel_context.world_pg.size()))}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )
        test_tensor = torch.tensor([dist.get_rank(parallel_context.world_pg)], device=torch.device("cuda"))
        test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(parallel_context.world_pg.size())]
        dist.all_gather(test_tensor_list, test_tensor, group=parallel_context.world_pg, async_op=False)
        dist.barrier()
        log_rank(
            f"[TEST] NCCL sync for ranks {[t.item() for t in test_tensor_list]}",
            logger=logger,
            level=logging.WARNING,
            group=parallel_context.dp_pg,
            rank=0,
        )

        del test_tensor_list
        del test_tensor

    with htrack_block("Model loading"):
        # We need to load the model in the main process first to avoid downloading the model multiple times
        model = BRRRModel(
            checkpoint_path=args.checkpoint_config_path.replace("config.yaml", ""),
            model_args=brrr_config.model,
            tokenizer=brrr_config.tokenizer,
            parallel_context=parallel_context,
            parallel_config=lighteval_config.parallelism,
            lighteval_config=lighteval_config,
            batch_size=lighteval_config.batch_size,
            cache_dir=os.environ.get("HF_HOME", "/scratch"),
            debug_one_layer_model=False,
            s5cmd_path=args.s5cmd_path,
            s5cmd_numworkers=args.s5cmd_numworkers,
            s5cmd_concurrency=args.s5cmd_concurrency,
            model_class=MistralForTraining,
        )
        model_info = ModelInfo(model_name=f"{brrr_config.general.run}/{brrr_config.general.step}")
        evaluation_tracker.general_config_logger.log_model_info(model_info)

    with htrack_block("Tasks loading"):
        with local_ranks_zero_first():
            tasks_selection = lighteval_config.tasks.tasks
            if lighteval_config.tasks.custom_tasks_file:
                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
                if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
                    tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

            task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
            task_dict = Registry(cache_dir=cache_dir).get_task_dict(
                task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
            )
            # Loading all the dataset in a distributed manner
            LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)

            evaluation_tracker.task_config_logger.log(task_dict)

            hlog("Loading documents, and requests")
            requests, docs = create_requests_from_tasks(
                task_dict=task_dict,
                fewshot_dict=few_shots_dict,
                num_fewshot_seeds=lighteval_config.tasks.num_fewshot_seeds or 1,
                lm=model,
                max_samples=lighteval_config.tasks.max_samples,
                evaluation_tracker=evaluation_tracker,
                use_chat_template=False,
            )

    with htrack_block("Setting seeds and waiting for all processes"):
        hlog(f"setting seed to {1234} for random and numpy")
        random.seed(1234)
        np.random.seed(1234)
        dist.barrier()

    with htrack_block("Evaluation"):
        hlog(f"Evaluate on {len(task_names_list)} tasks.")
        evaluation_tracker = evaluate(
            lm=model,
            requests_dict=requests,
            docs=docs,
            task_dict=task_dict,
            override_bs=lighteval_config.batch_size,
            evaluation_tracker=evaluation_tracker,
        )

    if dist.get_rank(parallel_context.world_pg) == 0:
        with htrack_block("Compiling and saving results"):
            evaluation_tracker.general_config_logger.log_end_time()
            evaluation_tracker.metrics_logger.aggregate(task_dict=task_dict, bootstrap_iters=1000)
            evaluation_tracker.details_logger.aggregate()

            if lighteval_config.logging.local_output_path:
                evaluation_tracker.save(
                    output_dir=lighteval_config.logging.local_output_path,
                    push_results_to_hub=lighteval_config.logging.push_results_to_hub,
                    push_details_to_hub=lighteval_config.logging.push_details_to_hub,
                    public=False,
                    push_results_to_tensorboard=lighteval_config.logging.push_results_to_tensorboard,
                )

            if lighteval_config.logging.push_results_to_tensorboard:
                push_results_to_tensorboard(
                    config=brrr_config,
                    results=evaluation_tracker.metrics_logger.metric_aggregated,
                    details=evaluation_tracker.details_logger.details,
                )
            if lighteval_config.wandb is not None:
                push_results_to_wandb(
                    config=brrr_config,
                    results=evaluation_tracker.metrics_logger.metric_aggregated,
                    details=evaluation_tracker.details_logger.details,
                )

        final_dict = evaluation_tracker.generate_final_dict()

        hlog(make_results_table(final_dict))

        return final_dict


if __name__ == "__main__":
    parser = get_parser()
    args, unknowns = parser.parse_known_args()
    main(args)
run_train.py
CHANGED
@@ -8,11 +8,11 @@ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
 ```
 """
 import argparse
+from nanotron.trainer import DistributedTrainer
 
-from config_tiny_mistral import MistralConfig
 from dataloader import get_dataloader
 from modeling_mistral import MistralForTraining
-from
+from config_tiny_mistral import MistralConfig
 
 
 def get_args():