BryanW commited on
Commit
3182520
·
verified ·
1 Parent(s): adf4098

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/Dream/Dream_Baseline/eval_instruct/.gitignore +26 -0
  2. Prism/Dream/Dream_Baseline/eval_instruct/README.md +16 -0
  3. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/__init__.py +7 -0
  4. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/__main__.py +512 -0
  5. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/__init__.py +0 -0
  6. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/filter.py +56 -0
  7. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/group.py +115 -0
  8. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/instance.py +38 -0
  9. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/metrics.py +578 -0
  10. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/model.py +493 -0
  11. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/registry.py +196 -0
  12. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/samplers.py +232 -0
  13. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/task.py +1839 -0
  14. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/caching/__init__.py +0 -0
  15. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/caching/cache.py +59 -0
  16. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/__init__.py +0 -0
  17. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/archiver.py +174 -0
  18. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/decontaminate.py +166 -0
  19. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/janitor.py +328 -0
  20. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/evaluator.py +736 -0
  21. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/evaluator_utils.py +554 -0
  22. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/__init__.py +25 -0
  23. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/custom.py +17 -0
  24. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/decontamination.py +25 -0
  25. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/extraction.py +188 -0
  26. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/selection.py +61 -0
  27. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/transformation.py +56 -0
  28. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/__init__.py +2 -0
  29. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/evaluation_tracker.py +524 -0
  30. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/utils.py +149 -0
  31. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/wandb_logger.py +358 -0
  32. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/__init__.py +17 -0
  33. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/diffllm.py +563 -0
  34. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/dummy.py +41 -0
  35. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/hts_sampler.py +256 -0
  36. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/huggingface.py +1459 -0
  37. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/utils.py +731 -0
  38. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/verifier.py +155 -0
  39. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/prompts/__init__.py +128 -0
  40. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/README.md +165 -0
  41. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/__init__.py +669 -0
  42. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5.yaml +19 -0
  43. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5_instruct_noprefix.yaml +15 -0
  44. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_64.yaml +19 -0
  45. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_instruct.yaml +11 -0
  46. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_instruct_noprefix.yaml +15 -0
  47. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/sanitize_utils.py +121 -0
  48. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/utils.py +52 -0
  49. Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/utils.py +552 -0
  50. Prism/Dream/Dream_Baseline/eval_instruct/pyproject.toml +134 -0
Prism/Dream/Dream_Baseline/eval_instruct/.gitignore ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ env
2
+ *.pyc
3
+ output/
4
+ output5/
5
+ data/
6
+ lm_cache
7
+ .idea
8
+ build
9
+ dist
10
+ *.egg-info
11
+ venv
12
+ .venv/
13
+ .vscode/
14
+ temp
15
+ __pycache__
16
+ .ipynb_checkpoints
17
+ temp
18
+ test_logs/
19
+ # IPython
20
+ profile_default/
21
+ ipython_config.py
22
+ # don't track (the default location of) the cached requests
23
+ lm_eval/caching/.cache
24
+ # don't track files created by wandb
25
+ wandb
26
+ examples/wandb
Prism/Dream/Dream_Baseline/eval_instruct/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dream-Instruct Evaluation Toolkit
2
+ This toolkit contains the code Dream-Instruct models make use of for evaluation.
3
+
4
+ ## Quickstart
5
+ To install the toolkit, run:
6
+ ```
7
+ pip install -e ".[ifeval,math]"
8
+ ```
9
+
10
+ We provide a script to evaluate [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B):
11
+ ```
12
+ bash eval.sh
13
+ ```
14
+
15
+ ## Acknowledgement
16
+ This is a fork of [EleutherAI/lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main).
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ from .evaluator import evaluate, simple_evaluate
5
+
6
+
7
+ __version__ = "0.4.8"
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/__main__.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import sys
6
+ from functools import partial
7
+ from typing import Union
8
+
9
+ from lm_eval import evaluator, utils
10
+ from lm_eval.evaluator import request_caching_arg_to_dict
11
+ from lm_eval.loggers import EvaluationTracker, WandbLogger
12
+ from lm_eval.tasks import TaskManager
13
+ from lm_eval.utils import (
14
+ handle_non_serializable,
15
+ make_table,
16
+ simple_parse_args_string,
17
+ )
18
+
19
+
20
+ def try_parse_json(value: str) -> Union[str, dict, None]:
21
+ if value is None:
22
+ return None
23
+ try:
24
+ return json.loads(value)
25
+ except json.JSONDecodeError:
26
+ if "{" in value:
27
+ raise argparse.ArgumentTypeError(
28
+ f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
29
+ )
30
+ return value
31
+
32
+
33
+ def _int_or_none_list_arg_type(
34
+ min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
35
+ ):
36
+ def parse_value(item):
37
+ item = item.strip().lower()
38
+ if item == "none":
39
+ return None
40
+ try:
41
+ return int(item)
42
+ except ValueError:
43
+ raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
44
+
45
+ items = [parse_value(v) for v in value.split(split_char)]
46
+ num_items = len(items)
47
+
48
+ if num_items == 1:
49
+ # Makes downstream handling the same for single and multiple values
50
+ items = items * max_len
51
+ elif num_items < min_len or num_items > max_len:
52
+ raise argparse.ArgumentTypeError(
53
+ f"Argument requires {max_len} integers or None, separated by '{split_char}'"
54
+ )
55
+ elif num_items != max_len:
56
+ logging.warning(
57
+ f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
58
+ "Missing values will be filled with defaults."
59
+ )
60
+ default_items = [parse_value(v) for v in defaults.split(split_char)]
61
+ items.extend(
62
+ default_items[num_items:]
63
+ ) # extend items list with missing defaults
64
+
65
+ return items
66
+
67
+
68
+ def check_argument_types(parser: argparse.ArgumentParser):
69
+ """
70
+ Check to make sure all CLI args are typed, raises error if not
71
+ """
72
+ for action in parser._actions:
73
+ if action.dest != "help" and not action.const:
74
+ if action.type is None:
75
+ raise ValueError(
76
+ f"Argument '{action.dest}' doesn't have a type specified."
77
+ )
78
+ else:
79
+ continue
80
+
81
+
82
+ def setup_parser() -> argparse.ArgumentParser:
83
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
84
+ parser.add_argument(
85
+ "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
86
+ )
87
+ parser.add_argument(
88
+ "--tasks",
89
+ "-t",
90
+ default=None,
91
+ type=str,
92
+ metavar="task1,task2",
93
+ help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
94
+ )
95
+ parser.add_argument(
96
+ "--model_args",
97
+ "-a",
98
+ default="",
99
+ type=try_parse_json,
100
+ help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
101
+ )
102
+ parser.add_argument(
103
+ "--num_fewshot",
104
+ "-f",
105
+ type=int,
106
+ default=None,
107
+ metavar="N",
108
+ help="Number of examples in few-shot context",
109
+ )
110
+ parser.add_argument(
111
+ "--batch_size",
112
+ "-b",
113
+ type=str,
114
+ default=1,
115
+ metavar="auto|auto:N|N",
116
+ help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
117
+ )
118
+ parser.add_argument(
119
+ "--max_batch_size",
120
+ type=int,
121
+ default=None,
122
+ metavar="N",
123
+ help="Maximal batch size to try with --batch_size auto.",
124
+ )
125
+ parser.add_argument(
126
+ "--device",
127
+ type=str,
128
+ default=None,
129
+ help="Device to use (e.g. cuda, cuda:0, cpu).",
130
+ )
131
+ parser.add_argument(
132
+ "--output_path",
133
+ "-o",
134
+ default=None,
135
+ type=str,
136
+ metavar="DIR|DIR/file.json",
137
+ help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
138
+ )
139
+ parser.add_argument(
140
+ "--limit",
141
+ "-L",
142
+ type=float,
143
+ default=None,
144
+ metavar="N|0<N<1",
145
+ help="Limit the number of examples per task. "
146
+ "If <1, limit is a percentage of the total number of examples.",
147
+ )
148
+ parser.add_argument(
149
+ "--use_cache",
150
+ "-c",
151
+ type=str,
152
+ default=None,
153
+ metavar="DIR",
154
+ help="A path to a sqlite db file for caching model responses. `None` if not caching.",
155
+ )
156
+ parser.add_argument(
157
+ "--cache_requests",
158
+ type=str,
159
+ default=None,
160
+ choices=["true", "refresh", "delete"],
161
+ help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
162
+ )
163
+ parser.add_argument(
164
+ "--check_integrity",
165
+ action="store_true",
166
+ help="Whether to run the relevant part of the test suite for the tasks.",
167
+ )
168
+ parser.add_argument(
169
+ "--write_out",
170
+ "-w",
171
+ action="store_true",
172
+ default=False,
173
+ help="Prints the prompt for the first few documents.",
174
+ )
175
+ parser.add_argument(
176
+ "--log_samples",
177
+ "-s",
178
+ action="store_true",
179
+ default=False,
180
+ help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
181
+ )
182
+ parser.add_argument(
183
+ "--system_instruction",
184
+ type=str,
185
+ default=None,
186
+ help="System instruction to be used in the prompt",
187
+ )
188
+ parser.add_argument(
189
+ "--apply_chat_template",
190
+ type=str,
191
+ nargs="?",
192
+ const=True,
193
+ default=False,
194
+ help=(
195
+ "If True, apply chat template to the prompt. "
196
+ "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
197
+ "To apply a specific template from the available list of templates, provide the template name as an argument. "
198
+ "E.g. `--apply_chat_template template_name`"
199
+ ),
200
+ )
201
+ parser.add_argument(
202
+ "--fewshot_as_multiturn",
203
+ action="store_true",
204
+ default=False,
205
+ help="If True, uses the fewshot as a multi-turn conversation",
206
+ )
207
+ parser.add_argument(
208
+ "--show_config",
209
+ action="store_true",
210
+ default=False,
211
+ help="If True, shows the the full config of all tasks at the end of the evaluation.",
212
+ )
213
+ parser.add_argument(
214
+ "--include_path",
215
+ type=str,
216
+ default=None,
217
+ metavar="DIR",
218
+ help="Additional path to include if there are external tasks to include.",
219
+ )
220
+ parser.add_argument(
221
+ "--gen_kwargs",
222
+ type=try_parse_json,
223
+ default=None,
224
+ help=(
225
+ "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
226
+ """ e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
227
+ ),
228
+ )
229
+ parser.add_argument(
230
+ "--verbosity",
231
+ "-v",
232
+ type=str.upper,
233
+ default=None,
234
+ metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
235
+ help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
236
+ )
237
+ parser.add_argument(
238
+ "--wandb_args",
239
+ type=str,
240
+ default="",
241
+ help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
242
+ )
243
+ parser.add_argument(
244
+ "--wandb_config_args",
245
+ type=str,
246
+ default="",
247
+ help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
248
+ )
249
+ parser.add_argument(
250
+ "--hf_hub_log_args",
251
+ type=str,
252
+ default="",
253
+ help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
254
+ )
255
+ parser.add_argument(
256
+ "--predict_only",
257
+ "-x",
258
+ action="store_true",
259
+ default=False,
260
+ help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
261
+ )
262
+ default_seed_string = "0,1234,1234,1234"
263
+ parser.add_argument(
264
+ "--seed",
265
+ type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
266
+ default=default_seed_string, # for backward compatibility
267
+ help=(
268
+ "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
269
+ "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
270
+ "respectively, or a single integer to set the same seed for all four.\n"
271
+ f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
272
+ "(for backward compatibility).\n"
273
+ "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
274
+ "Here numpy's seed is not set since the second value is `None`.\n"
275
+ "E.g, `--seed 42` sets all four seeds to 42."
276
+ ),
277
+ )
278
+ parser.add_argument(
279
+ "--trust_remote_code",
280
+ action="store_true",
281
+ help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
282
+ )
283
+ parser.add_argument(
284
+ "--confirm_run_unsafe_code",
285
+ action="store_true",
286
+ help="Confirm that you understand the risks of running unsafe code for tasks that require it",
287
+ )
288
+ parser.add_argument(
289
+ "--metadata",
290
+ type=json.loads,
291
+ default=None,
292
+ help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
293
+ )
294
+ return parser
295
+
296
+
297
+ def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
298
+ check_argument_types(parser)
299
+ return parser.parse_args()
300
+
301
+
302
+ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
303
+ if not args:
304
+ # we allow for args to be passed externally, else we parse them ourselves
305
+ parser = setup_parser()
306
+ args = parse_eval_args(parser)
307
+
308
+ if args.wandb_args:
309
+ wandb_args_dict = simple_parse_args_string(args.wandb_args)
310
+ wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
311
+ wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
312
+
313
+ utils.setup_logging(args.verbosity)
314
+ eval_logger = logging.getLogger(__name__)
315
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
316
+
317
+ # update the evaluation tracker args with the output path and the HF token
318
+ if args.output_path:
319
+ args.hf_hub_log_args += f",output_path={args.output_path}"
320
+ if os.environ.get("HF_TOKEN", None):
321
+ args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
322
+ evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
323
+ evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
324
+
325
+ if args.predict_only:
326
+ args.log_samples = True
327
+ if (args.log_samples or args.predict_only) and not args.output_path:
328
+ raise ValueError(
329
+ "Specify --output_path if providing --log_samples or --predict_only"
330
+ )
331
+
332
+ if args.fewshot_as_multiturn and args.apply_chat_template is False:
333
+ raise ValueError(
334
+ "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
335
+ )
336
+
337
+ if args.include_path is not None:
338
+ eval_logger.info(f"Including path: {args.include_path}")
339
+ metadata = (
340
+ simple_parse_args_string(args.model_args)
341
+ if isinstance(args.model_args, str)
342
+ else args.model_args
343
+ if isinstance(args.model_args, dict)
344
+ else {}
345
+ ) | (
346
+ args.metadata
347
+ if isinstance(args.metadata, dict)
348
+ else simple_parse_args_string(args.metadata)
349
+ )
350
+
351
+ task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
352
+
353
+ if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
354
+ eval_logger.warning(
355
+ "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
356
+ )
357
+
358
+ if args.limit:
359
+ eval_logger.warning(
360
+ " --limit SHOULD ONLY BE USED FOR TESTING."
361
+ "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
362
+ )
363
+
364
+ if args.tasks is None:
365
+ eval_logger.error("Need to specify task to evaluate.")
366
+ sys.exit()
367
+ elif args.tasks == "list":
368
+ print(task_manager.list_all_tasks())
369
+ sys.exit()
370
+ elif args.tasks == "list_groups":
371
+ print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
372
+ sys.exit()
373
+ elif args.tasks == "list_tags":
374
+ print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
375
+ sys.exit()
376
+ elif args.tasks == "list_subtasks":
377
+ print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
378
+ sys.exit()
379
+ else:
380
+ if os.path.isdir(args.tasks):
381
+ import glob
382
+
383
+ task_names = []
384
+ yaml_path = os.path.join(args.tasks, "*.yaml")
385
+ for yaml_file in glob.glob(yaml_path):
386
+ config = utils.load_yaml_config(yaml_file)
387
+ task_names.append(config)
388
+ else:
389
+ task_list = args.tasks.split(",")
390
+ task_names = task_manager.match_tasks(task_list)
391
+ for task in [task for task in task_list if task not in task_names]:
392
+ if os.path.isfile(task):
393
+ config = utils.load_yaml_config(task)
394
+ task_names.append(config)
395
+ task_missing = [
396
+ task for task in task_list if task not in task_names and "*" not in task
397
+ ] # we don't want errors if a wildcard ("*") task name was used
398
+
399
+ if task_missing:
400
+ missing = ", ".join(task_missing)
401
+ eval_logger.error(
402
+ f"Tasks were not found: {missing}\n"
403
+ f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
404
+ )
405
+ raise ValueError(
406
+ f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
407
+ )
408
+
409
+ # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
410
+ if args.trust_remote_code:
411
+ eval_logger.info(
412
+ "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
413
+ )
414
+ # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
415
+ # because it's already been determined based on the prior env var before launching our
416
+ # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
417
+ import datasets
418
+
419
+ datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
420
+
421
+ args.model_args = args.model_args + ",trust_remote_code=True"
422
+ eval_logger.info(
423
+ f"Selected Tasks: {task_names}"
424
+ ) if eval_logger.getEffectiveLevel() >= logging.INFO else print(
425
+ f"Selected Tasks: {task_names}"
426
+ )
427
+
428
+ request_caching_args = request_caching_arg_to_dict(
429
+ cache_requests=args.cache_requests
430
+ )
431
+
432
+ results = evaluator.simple_evaluate(
433
+ model=args.model,
434
+ model_args=args.model_args,
435
+ tasks=task_names,
436
+ num_fewshot=args.num_fewshot,
437
+ batch_size=args.batch_size,
438
+ max_batch_size=args.max_batch_size,
439
+ device=args.device,
440
+ use_cache=args.use_cache,
441
+ limit=args.limit,
442
+ check_integrity=args.check_integrity,
443
+ write_out=args.write_out,
444
+ log_samples=args.log_samples,
445
+ evaluation_tracker=evaluation_tracker,
446
+ system_instruction=args.system_instruction,
447
+ apply_chat_template=args.apply_chat_template,
448
+ fewshot_as_multiturn=args.fewshot_as_multiturn,
449
+ gen_kwargs=args.gen_kwargs,
450
+ task_manager=task_manager,
451
+ predict_only=args.predict_only,
452
+ random_seed=args.seed[0],
453
+ numpy_random_seed=args.seed[1],
454
+ torch_random_seed=args.seed[2],
455
+ fewshot_random_seed=args.seed[3],
456
+ confirm_run_unsafe_code=args.confirm_run_unsafe_code,
457
+ metadata=metadata,
458
+ **request_caching_args,
459
+ )
460
+
461
+ if results is not None:
462
+ if args.log_samples:
463
+ samples = results.pop("samples")
464
+ dumped = json.dumps(
465
+ results, indent=2, default=handle_non_serializable, ensure_ascii=False
466
+ )
467
+ if args.show_config:
468
+ print(dumped)
469
+
470
+ batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
471
+
472
+ # Add W&B logging
473
+ if args.wandb_args:
474
+ try:
475
+ wandb_logger.post_init(results)
476
+ wandb_logger.log_eval_result()
477
+ if args.log_samples:
478
+ wandb_logger.log_eval_samples(samples)
479
+ except Exception as e:
480
+ eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
481
+
482
+ evaluation_tracker.save_results_aggregated(
483
+ results=results, samples=samples if args.log_samples else None
484
+ )
485
+
486
+ if args.log_samples:
487
+ for task_name, config in results["configs"].items():
488
+ evaluation_tracker.save_results_samples(
489
+ task_name=task_name, samples=samples[task_name]
490
+ )
491
+
492
+ if (
493
+ evaluation_tracker.push_results_to_hub
494
+ or evaluation_tracker.push_samples_to_hub
495
+ ):
496
+ evaluation_tracker.recreate_metadata_card()
497
+
498
+ print(
499
+ f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
500
+ f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
501
+ )
502
+ print(make_table(results))
503
+ if "groups" in results:
504
+ print(make_table(results, "groups"))
505
+
506
+ if args.wandb_args:
507
+ # Tear down wandb run once all the logging is done.
508
+ wandb_logger.run.finish()
509
+
510
+
511
+ if __name__ == "__main__":
512
+ cli_evaluate()
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/__init__.py ADDED
File without changes
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/filter.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Iterable, List, Union
4
+
5
+ from lm_eval.api.instance import Instance
6
+
7
+
8
+ class Filter(ABC):
9
+ """
10
+ Filter classes operate on a per-task level.
11
+ They take all model outputs (`instance.resps` for all `task.instances`)
12
+ across all instances of a task, and perform operations.
13
+ In a single run, one can configure any number of separate filters or lists of filters.
14
+
15
+ """
16
+
17
+ def __init__(self, **kwargs) -> None:
18
+ """
19
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
20
+ """
21
+
22
+ @abstractmethod
23
+ def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
24
+ """
25
+ Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
26
+ Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
27
+ if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
28
+ [<filtered resps for instance 0>, <filtered resps for instance 1>]
29
+ """
30
+ return resps
31
+
32
+
33
+ @dataclass
34
+ class FilterEnsemble:
35
+ """
36
+ FilterEnsemble creates a pipeline applying multiple filters.
37
+ Its intended usage is to stack multiple post-processing steps in order.
38
+ `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
39
+ pipeline separately.
40
+ """
41
+
42
+ name: str
43
+ filters: List[Callable[[], Filter]]
44
+
45
+ def apply(self, instances: List[Instance]) -> None:
46
+ resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
47
+ resps, docs = list(resps), list(docs)
48
+
49
+ for f in self.filters:
50
+ # apply filters in sequence
51
+ resps = f().apply(resps, docs)
52
+
53
+ # add the end results after filtering to filtered_requests of their respective source instances.
54
+ # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
55
+ for inst, resp in zip(instances, resps):
56
+ inst.filtered_resps[self.name] = resp
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/group.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from dataclasses import asdict, dataclass
3
+ from inspect import getsource
4
+ from typing import Any, Callable, List, Optional, Union
5
+
6
+
7
+ @dataclass
8
+ class AggMetricConfig(dict):
9
+ metric: Optional[str] = None
10
+ aggregation: Optional[str] = "mean"
11
+ weight_by_size: Optional[str] = False
12
+ # list of filter names which should be incorporated into the aggregated metric.
13
+ filter_list: Optional[Union[str, list]] = "none"
14
+
15
+ def __post_init__(self):
16
+ if self.aggregation != "mean" and not callable(self.aggregation):
17
+ raise ValueError(
18
+ f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
19
+ )
20
+
21
+ if isinstance(self.filter_list, str):
22
+ self.filter_list = [self.filter_list]
23
+
24
+
25
+ @dataclass
26
+ class GroupConfig(dict):
27
+ group: Optional[str] = None
28
+ group_alias: Optional[str] = None
29
+ task: Optional[Union[str, list]] = None
30
+ aggregate_metric_list: Optional[
31
+ Union[List[AggMetricConfig], AggMetricConfig, dict]
32
+ ] = None
33
+ metadata: Optional[dict] = (
34
+ None # by default, not used in the code. allows for users to pass arbitrary info to tasks
35
+ )
36
+
37
+ def __getitem__(self, item):
38
+ return getattr(self, item)
39
+
40
+ def __setitem__(self, item, value):
41
+ return setattr(self, item, value)
42
+
43
+ def __post_init__(self):
44
+ if self.aggregate_metric_list is not None:
45
+ if isinstance(self.aggregate_metric_list, dict):
46
+ self.aggregate_metric_list = [self.aggregate_metric_list]
47
+
48
+ self.aggregate_metric_list = [
49
+ AggMetricConfig(**item) if isinstance(item, dict) else item
50
+ for item in self.aggregate_metric_list
51
+ ]
52
+
53
+ def to_dict(self, keep_callable: bool = False) -> dict:
54
+ """dumps the current config as a dictionary object, as a printable format.
55
+ null fields will not be printed.
56
+ Used for dumping results alongside full task configuration
57
+
58
+ :return: dict
59
+ A printable dictionary version of the TaskConfig object.
60
+
61
+ # TODO: should any default value in the TaskConfig not be printed?
62
+ """
63
+ cfg_dict = asdict(self)
64
+ # remove values that are `None`
65
+ for k, v in list(cfg_dict.items()):
66
+ if callable(v):
67
+ cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
68
+ return cfg_dict
69
+
70
+ def serialize_function(
71
+ self, value: Union[Callable, str], keep_callable=False
72
+ ) -> Union[Callable, str]:
73
+ """Serializes a given function or string.
74
+
75
+ If 'keep_callable' is True, the original callable is returned.
76
+ Otherwise, attempts to return the source code of the callable using 'getsource'.
77
+ """
78
+ if keep_callable:
79
+ return value
80
+ else:
81
+ try:
82
+ return getsource(value)
83
+ except (TypeError, OSError):
84
+ return str(value)
85
+
86
+
87
+ class ConfigurableGroup(abc.ABC):
88
+ def __init__(
89
+ self,
90
+ config: Optional[dict] = None,
91
+ ) -> None:
92
+ self._config = GroupConfig(**config)
93
+
94
+ @property
95
+ def group(self):
96
+ return self._config.group
97
+
98
+ @property
99
+ def group_alias(self):
100
+ return self._config.group_alias
101
+
102
+ @property
103
+ def version(self):
104
+ return self._config.version
105
+
106
+ @property
107
+ def config(self):
108
+ return self._config.to_dict()
109
+
110
+ @property
111
+ def group_name(self) -> Any:
112
+ return self._config.group
113
+
114
+ def __repr__(self):
115
+ return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/instance.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal, Optional, Tuple
3
+
4
+
5
+ OutputType = Literal[
6
+ "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
7
+ ]
8
+
9
+
10
+ @dataclass
11
+ class Instance:
12
+ request_type: OutputType
13
+ doc: dict
14
+ arguments: tuple
15
+ idx: int
16
+ metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
17
+ default_factory=lambda: (None, None, None)
18
+ )
19
+ resps: list = field(default_factory=list)
20
+ filtered_resps: dict = field(default_factory=dict)
21
+
22
+ # initialized after init
23
+ task_name: Optional[str] = None
24
+ doc_id: Optional[int] = None
25
+ repeats: Optional[int] = None
26
+
27
+ def __post_init__(self) -> None:
28
+ # unpack metadata field
29
+ self.task_name, self.doc_id, self.repeats = self.metadata
30
+
31
+ @property
32
+ def args(self):
33
+ """
34
+ Returns (string,) where `string` is the string to calculate loglikelihood over
35
+ """
36
+ return (
37
+ self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
38
+ )
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/metrics.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import random
4
+ import re
5
+ import string
6
+ from collections.abc import Iterable
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ import sacrebleu
11
+
12
+ from lm_eval.api.registry import register_aggregation, register_metric
13
+
14
+
15
+ eval_logger = logging.getLogger(__name__)
16
+
17
+
18
+ # Register Aggregations First
19
+ @register_aggregation("bypass")
20
+ def bypass_agg(arr):
21
+ return 999
22
+
23
+
24
+ @register_aggregation("nanmean")
25
+ def nanmean(arr):
26
+ if len(arr) == 0 or all(np.isnan(arr)):
27
+ return np.nan
28
+ return np.nanmean(arr)
29
+
30
+
31
+ @register_aggregation("mean")
32
+ def mean(arr):
33
+ return sum(arr) / len(arr)
34
+
35
+
36
+ @register_aggregation("median")
37
+ def median(arr):
38
+ return arr[len(arr) // 2]
39
+
40
+
41
+ # Certain metrics must be calculated across all documents in a benchmark.
42
+ # We use them as aggregation metrics, paired with no-op passthrough metric fns.
43
+ @register_aggregation("perplexity")
44
+ def perplexity(items):
45
+ return math.exp(-mean(items))
46
+
47
+
48
+ @register_aggregation("weighted_perplexity")
49
+ def weighted_perplexity(items):
50
+ return math.exp(-weighted_mean(items))
51
+
52
+
53
+ @register_aggregation("bits_per_byte")
54
+ def bits_per_byte(items):
55
+ return -weighted_mean(items) / math.log(2)
56
+
57
+
58
+ @register_aggregation("f1")
59
+ def f1_score(items):
60
+ from sklearn.metrics import f1_score
61
+
62
+ unzipped_list = list(zip(*items))
63
+ golds = unzipped_list[0]
64
+ preds = unzipped_list[1]
65
+ fscore = f1_score(golds, preds)
66
+
67
+ return np.max(fscore)
68
+
69
+
70
+ @register_aggregation("matthews_corrcoef")
71
+ def matthews_corrcoef(items):
72
+ from sklearn.metrics import matthews_corrcoef
73
+
74
+ unzipped_list = list(zip(*items))
75
+ golds = unzipped_list[0]
76
+ preds = unzipped_list[1]
77
+ return matthews_corrcoef(golds, preds)
78
+
79
+
80
+ @register_aggregation("bleu")
81
+ def bleu(items):
82
+ """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
83
+ for evaluating a generated sentence to a reference sentence. It counts matching
84
+ n-grams in the candidate translation to n-grams in the reference text, where
85
+ 1-gram or unigram would be each token and a bigram comparison would be each
86
+ word pair. The comparison is made regardless of word order
87
+ Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
88
+ Paper: https://www.aclweb.org/anthology/P02-1040/
89
+
90
+ Higher is better
91
+ """
92
+ refs = list(zip(*items))[0]
93
+ preds = list(zip(*items))[1]
94
+ refs, preds = _sacreformat(refs, preds)
95
+ return sacrebleu.corpus_bleu(preds, refs).score
96
+
97
+
98
+ @register_aggregation("chrf")
99
+ def chrf(items):
100
+ """chrF++ is a tool for automatic evaluation of machine translation output
101
+ based on character n-gram precision and recall enhanced with word n-grams.
102
+ Source: https://github.com/m-popovic/chrF
103
+ Paper: https://www.aclweb.org/anthology/W15-3049.pdf
104
+
105
+ Higher is better # TODO I think
106
+ """
107
+ refs = list(zip(*items))[0]
108
+ preds = list(zip(*items))[1]
109
+ refs, preds = _sacreformat(refs, preds)
110
+ return sacrebleu.corpus_chrf(preds, refs).score
111
+
112
+
113
+ @register_aggregation("ter")
114
+ def ter(items):
115
+ """Translation Error Rate is an error metric for machine translation that
116
+ measures the number of edits required to change a system output into one
117
+ of the references
118
+ Source: http://www.cs.umd.edu/~snover/tercom/
119
+ Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
120
+
121
+ Lower is better
122
+ """
123
+ refs = list(zip(*items))[0]
124
+ preds = list(zip(*items))[1]
125
+ refs, preds = _sacreformat(refs, preds)
126
+ return sacrebleu.corpus_ter(preds, refs).score
127
+
128
+
129
+ @register_aggregation("brier_score")
130
+ def brier_score(items): # This is a passthrough function
131
+ gold, predictions = list(zip(*items))
132
+ bs, num_class = np.array(predictions).shape
133
+
134
+ gold = list(gold)
135
+ gold_one_hot = np.eye(num_class)[gold]
136
+ return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
137
+
138
+
139
+ @register_metric(
140
+ metric="brier_score",
141
+ higher_is_better=False,
142
+ output_type=["multiple_choice"],
143
+ aggregation="brier_score",
144
+ )
145
+ def brier_score_fn(items): # This is a passthrough function
146
+ return items
147
+
148
+
149
+ @register_metric(
150
+ metric="acc",
151
+ higher_is_better=True,
152
+ output_type=["loglikelihood", "multiple_choice"],
153
+ aggregation="mean",
154
+ )
155
+ def acc_fn(items): # This is a passthrough function
156
+ return items
157
+
158
+
159
+ @register_metric(
160
+ metric="acc_norm",
161
+ higher_is_better=True,
162
+ output_type=["loglikelihood", "multiple_choice"],
163
+ aggregation="mean",
164
+ )
165
+ def acc_norm_fn(items): # This is a passthrough function
166
+ return items
167
+
168
+
169
+ @register_metric(
170
+ metric="acc_mutual_info",
171
+ higher_is_better=True,
172
+ output_type="multiple_choice",
173
+ aggregation="mean",
174
+ )
175
+ def acc_mutual_info_fn(items): # This is a passthrough function
176
+ return items
177
+
178
+
179
+ ### the code used in the `exact_match_hf_evaluate` function is ported from
180
+ ### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
181
+ ### which is under the apache license.
182
+
183
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
184
+
185
+ # Licensed under the Apache License, Version 2.0 (the "License");
186
+ # you may not use this file except in compliance with the License.
187
+ # You may obtain a copy of the License at
188
+
189
+ # http://www.apache.org/licenses/LICENSE-2.0
190
+
191
+
192
+ # Unless required by applicable law or agreed to in writing, software
193
+ # distributed under the License is distributed on an "AS IS" BASIS,
194
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
195
+ # See the License for the specific language governing permissions and
196
+ # limitations under the License.
197
+ def exact_match_hf_evaluate(
198
+ predictions,
199
+ references,
200
+ regexes_to_ignore=None,
201
+ ignore_case=False,
202
+ ignore_punctuation=False,
203
+ ignore_numbers=False,
204
+ ):
205
+ if regexes_to_ignore is not None:
206
+ for s in regexes_to_ignore:
207
+ predictions = np.array([re.sub(s, "", x) for x in predictions])
208
+ references = np.array([re.sub(s, "", x) for x in references])
209
+ else:
210
+ predictions = np.asarray(predictions)
211
+ references = np.asarray(references)
212
+
213
+ if ignore_case:
214
+ predictions = np.char.lower(predictions)
215
+ references = np.char.lower(references)
216
+
217
+ if ignore_punctuation:
218
+ repl_table = string.punctuation.maketrans("", "", string.punctuation)
219
+ predictions = np.char.translate(predictions, table=repl_table)
220
+ references = np.char.translate(references, table=repl_table)
221
+
222
+ if ignore_numbers:
223
+ repl_table = string.digits.maketrans("", "", string.digits)
224
+ predictions = np.char.translate(predictions, table=repl_table)
225
+ references = np.char.translate(references, table=repl_table)
226
+
227
+ score_list = predictions == references
228
+
229
+ return {"exact_match": np.mean(score_list)}
230
+
231
+
232
+ ###
233
+
234
+
235
+ @register_metric(
236
+ metric="exact_match",
237
+ higher_is_better=True,
238
+ output_type="generate_until",
239
+ aggregation="mean",
240
+ )
241
+ def exact_match_fn(**kwargs):
242
+ return exact_match_hf_evaluate(**kwargs)
243
+
244
+
245
+ @register_metric(
246
+ metric="perplexity",
247
+ higher_is_better=False,
248
+ output_type="loglikelihood",
249
+ aggregation="perplexity",
250
+ )
251
+ def perplexity_fn(items): # This is a passthrough function
252
+ return items
253
+
254
+
255
+ @register_metric(
256
+ metric="word_perplexity",
257
+ higher_is_better=False,
258
+ output_type="loglikelihood_rolling",
259
+ aggregation="weighted_perplexity",
260
+ )
261
+ def word_perplexity_fn(items): # This is a passthrough function
262
+ return items
263
+
264
+
265
+ @register_metric(
266
+ metric="byte_perplexity",
267
+ higher_is_better=False,
268
+ output_type="loglikelihood_rolling",
269
+ aggregation="weighted_perplexity",
270
+ )
271
+ def byte_perplexity_fn(items): # This is a passthrough function
272
+ return items
273
+
274
+
275
+ @register_metric(
276
+ metric="bits_per_byte",
277
+ higher_is_better=False,
278
+ output_type="loglikelihood_rolling",
279
+ aggregation="bits_per_byte",
280
+ )
281
+ def bits_per_byte_fn(items): # This is a passthrough function
282
+ return items
283
+
284
+
285
+ def pop_stddev(arr):
286
+ mu = mean(arr)
287
+ return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
288
+
289
+
290
+ def sample_stddev(arr):
291
+ mu = mean(arr)
292
+ return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
293
+
294
+
295
+ def mean_stderr(arr):
296
+ return sample_stddev(arr) / math.sqrt(len(arr))
297
+
298
+
299
+ @register_metric(
300
+ metric="bypass",
301
+ higher_is_better=True,
302
+ output_type=["loglikelihood", "multiple_choice", "generate_until"],
303
+ aggregation="bypass",
304
+ )
305
+ def bypass(items):
306
+ return None
307
+
308
+
309
+ @register_metric(
310
+ metric="mcc",
311
+ higher_is_better=True,
312
+ output_type="multiple_choice",
313
+ aggregation="matthews_corrcoef",
314
+ )
315
+ def mcc_fn(items): # This is a passthrough function
316
+ return items
317
+
318
+
319
+ @register_metric(
320
+ metric="f1",
321
+ higher_is_better=True,
322
+ output_type="multiple_choice",
323
+ aggregation="f1",
324
+ )
325
+ def f1_fn(items): # This is a passthrough function
326
+ return items
327
+
328
+
329
+ @register_metric(
330
+ metric="bleu",
331
+ higher_is_better=True,
332
+ output_type="generate_until",
333
+ aggregation="bleu",
334
+ )
335
+ def bleu_fn(items): # This is a passthrough function
336
+ return items
337
+
338
+
339
+ @register_metric(
340
+ metric="chrf",
341
+ higher_is_better=True,
342
+ output_type="generate_until",
343
+ aggregation="chrf",
344
+ )
345
+ def chrf_fn(items): # This is a passthrough function
346
+ return items
347
+
348
+
349
+ @register_metric(
350
+ metric="ter",
351
+ higher_is_better=True,
352
+ output_type="generate_until",
353
+ aggregation="ter",
354
+ )
355
+ def ter_fn(items): # This is a passthrough function
356
+ return items
357
+
358
+
359
+ @register_metric(
360
+ metric="acc_all",
361
+ higher_is_better=True,
362
+ output_type="loglikelihood",
363
+ aggregation="mean",
364
+ )
365
+ def acc_all(items):
366
+ # Only count as correct if all answers are labeled correctly for each question
367
+ question_scoring_dict = {}
368
+ preds = list(zip(*items))[0]
369
+ docs = list(zip(*items))[1]
370
+
371
+ for doc, pred in zip(docs, preds):
372
+ paragraph_id = doc["idx"]["paragraph"]
373
+ question_id = doc["idx"]["question"]
374
+ if (paragraph_id, question_id) not in question_scoring_dict:
375
+ question_scoring_dict[(paragraph_id, question_id)] = []
376
+
377
+ gold_label = doc["label"] == 1
378
+
379
+ question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
380
+ acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
381
+ return acc
382
+
383
+
384
+ def acc_all_stderr(items):
385
+ # Only count as correct if all answers are labeled correctly for each question
386
+ question_scoring_dict = {}
387
+ preds = list(zip(*items))[0]
388
+ docs = list(zip(*items))[1]
389
+
390
+ for doc, pred in zip(docs, preds):
391
+ question_id = doc["idx"]["question"]
392
+ if question_id not in question_scoring_dict:
393
+ question_scoring_dict[question_id] = []
394
+
395
+ gold_label = doc["label"] == 1
396
+ question_scoring_dict[question_id].append(gold_label == pred)
397
+
398
+ acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
399
+ return acc
400
+
401
+
402
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
403
+ """Compute max metric between prediction and each ground truth."""
404
+ scores_for_ground_truths = []
405
+ for ground_truth in ground_truths:
406
+ score = metric_fn(prediction, ground_truth)
407
+ scores_for_ground_truths.append(score)
408
+ return max(scores_for_ground_truths)
409
+
410
+
411
+ def weighted_mean(items):
412
+ a, b = zip(*items)
413
+ return sum(a) / sum(b)
414
+
415
+
416
+ def is_non_str_iterable(obj):
417
+ return isinstance(obj, Iterable) and not isinstance(obj, str)
418
+
419
+
420
+ def _sacreformat(refs, preds):
421
+ """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
422
+ # Sacrebleu expects (List[str], List[List[str])
423
+ # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
424
+
425
+ # Note [ref1_stream] is the first reference for each pred.
426
+ # So lists are size N and (M, N) for N preds and M possible refs for each pred
427
+ # This is a different order of dimensions that I would expect
428
+
429
+ # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
430
+ # Must become List[List[str]] with the inner list corresponding to preds
431
+ if not is_non_str_iterable(refs):
432
+ refs = list(refs)
433
+ if not is_non_str_iterable(refs[0]):
434
+ refs = [[ref] for ref in refs]
435
+ refs = list(zip(*refs))
436
+ # Note the number of refs in each ref list much match the number of preds
437
+
438
+ # We expect preds to be List[str] or List[List[str]]. Must become List[str]
439
+ if not is_non_str_iterable(preds):
440
+ preds = list(preds)
441
+ if is_non_str_iterable(preds[0]):
442
+ assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
443
+ preds = [pred[0] for pred in preds]
444
+
445
+ return refs, preds
446
+
447
+
448
+ # stderr stuff
449
+
450
+
451
+ class _bootstrap_internal:
452
+ def __init__(self, f, n) -> None:
453
+ self.f = f
454
+ self.n = n
455
+
456
+ def __call__(self, v):
457
+ i, xs = v
458
+ rnd = random.Random()
459
+ rnd.seed(i)
460
+ res = []
461
+ for _ in range(self.n):
462
+ res.append(self.f(rnd.choices(xs, k=len(xs))))
463
+ return res
464
+
465
+
466
+ def bootstrap_stderr(f, xs, iters):
467
+ import multiprocessing as mp
468
+
469
+ pool = mp.Pool(mp.cpu_count())
470
+ # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
471
+ # equivalent to stderr calculated without Bessel's correction in the stddev.
472
+ # Unfortunately, I haven't been able to figure out what the right correction is
473
+ # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
474
+ # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
475
+ # Thankfully, shouldn't matter because our samples are pretty big usually anyways
476
+ res = []
477
+ chunk_size = min(1000, iters)
478
+ from tqdm import tqdm
479
+
480
+ print("bootstrapping for stddev:", f.__name__)
481
+ for bootstrap in tqdm(
482
+ pool.imap(
483
+ _bootstrap_internal(f, chunk_size),
484
+ [(i, xs) for i in range(iters // chunk_size)],
485
+ ),
486
+ total=iters // chunk_size,
487
+ ):
488
+ # sample w replacement
489
+ res.extend(bootstrap)
490
+
491
+ pool.close()
492
+ return sample_stddev(res)
493
+
494
+
495
+ def stderr_for_metric(metric, bootstrap_iters: int):
496
+ if bootstrap_iters <= 0:
497
+ # return no function (don't compute stderr) if bootstrap iters = 0
498
+ return None
499
+
500
+ bootstrappable = [
501
+ median,
502
+ matthews_corrcoef,
503
+ f1_score,
504
+ perplexity,
505
+ bleu,
506
+ chrf,
507
+ ter,
508
+ nanmean,
509
+ ]
510
+
511
+ if metric in bootstrappable:
512
+ return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
513
+
514
+ stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
515
+
516
+ return stderr.get(metric, None)
517
+
518
+
519
+ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
520
+ # Used to aggregate bootstrapped stderrs across subtasks in a group,
521
+ # when we are weighting by the size of each subtask.
522
+ #
523
+
524
+ assert len(stderrs) == len(sizes)
525
+
526
+ # formula source: https://en.wikipedia.org/wiki/Pooled_variance
527
+ # and: https://stats.stackexchange.com/a/4841331
528
+ # this empirically seems to match running `stderr_for_metric` on all instances
529
+ # from the subtasks concatenated with each other.
530
+ pooled_sample_var = (
531
+ sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
532
+ ) / (sum(sizes) - len(sizes))
533
+
534
+ return np.sqrt(pooled_sample_var / sum(sizes))
535
+
536
+
537
+ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
538
+ assert metrics is not None, (
539
+ "Need to pass a list of each subtask's metric for this stderr aggregation"
540
+ )
541
+ assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
542
+
543
+ # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
544
+ # This formula depends on sample means.
545
+ # removed because it seems to give erroneously huge stderrs for groupings of tasks
546
+ # and does not seem to match up with bootstrap-calculated stderrs for groups.
547
+
548
+ ### don't use this unless a statistician has told you it's the right thing to do ###
549
+
550
+ # accumulators: we'll aggregate pairwise N - 1 times
551
+ variance = stderrs[0] ** 2
552
+ curr_size = sizes[0]
553
+ curr_score = metrics[0]
554
+
555
+ for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
556
+ curr_score = ((curr_score * curr_size) + (score * size)) / (
557
+ curr_size + size
558
+ ) # NOTE: this assumes our aggregation fn is "mean"
559
+
560
+ variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
561
+ curr_size + size - 1
562
+ ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
563
+ curr_score - score
564
+ ) ** 2
565
+
566
+ return np.sqrt(variance)
567
+
568
+
569
+ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
570
+ # A helper function that is used to aggregate
571
+ # subtask scores cross-task.
572
+ # TODO: does not hold for non-mean aggregations
573
+ if not weight_by_size:
574
+ sizes = [1] * len(sizes)
575
+
576
+ assert len(metrics) == len(sizes)
577
+
578
+ return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/model.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
7
+
8
+ import transformers
9
+ from sqlitedict import SqliteDict
10
+ from tqdm import tqdm
11
+
12
+ from lm_eval import utils
13
+
14
+
15
+ eval_logger = logging.getLogger(__name__)
16
+
17
+ T = TypeVar("T", bound="LM")
18
+
19
+
20
+ class LM(abc.ABC):
21
+ def __init__(self) -> None:
22
+ """Defines the interface that should be implemented by all LM subclasses.
23
+ LMs are assumed to take text (strings) as input and yield strings as output
24
+ (inputs/outputs should be tokenization-agnostic.)
25
+
26
+ """
27
+ # set rank and world size to a single process, by default.
28
+ self._rank = 0
29
+ self._world_size = 1
30
+ self.cache_hook = CacheHook(None)
31
+
32
+ @abc.abstractmethod
33
+ def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
34
+ """Compute log-likelihood of generating a continuation from a context.
35
+ Downstream tasks should attempt to use loglikelihood instead of other
36
+ LM calls whenever possible.
37
+
38
+ :param requests: list[Instance]
39
+ A list of Instance objects, with property `args` which returns a tuple (context, continuation).
40
+ `context: str`
41
+ Context string. Implementations of LM must be able to handle an
42
+ empty context string.
43
+ `continuation: str`
44
+ The continuation over which log likelihood will be calculated. If
45
+ there is a word boundary, the space should be in the continuation.
46
+ For example, context="hello" continuation=" world" is correct.
47
+
48
+ :return: list[tuple[float, bool]]
49
+ A list of pairs (logprob, isgreedy)
50
+ `logprob: float`
51
+ The log probability of `continuation`.
52
+ `isgreedy`:
53
+ Whether `continuation` would be generated by greedy sampling from `context`.
54
+ """
55
+ pass
56
+
57
+ @abc.abstractmethod
58
+ def loglikelihood_rolling(self, requests) -> List[float]:
59
+ """Compute full log-likelihood of a string, with no truncation, for perplexity computation
60
+ - We will use the full max context length of the model.
61
+ - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
62
+ the max context length.
63
+ - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
64
+ which may simply concatenate multiple documents together.
65
+ - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
66
+ multiple chunks, the last input will still a full-sized context.
67
+ Example:
68
+ Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
69
+ Prefix: BOS/EOS
70
+ Max context length: 4
71
+ Resulting input/prediction pairs:
72
+
73
+ INPUT: BOS 0 1 2
74
+ PRED: 0 1 2 3
75
+
76
+ INPUT: 3 4 5 6
77
+ PRED: 4 5 6 7
78
+
79
+ INPUT: 5 6 7 8
80
+ PRED: 8 9
81
+
82
+ Observe that:
83
+ 1. Each token is predicted exactly once
84
+ 2. For the last pair, we provide the full context, but only score the last two tokens
85
+
86
+ :param requests: list[Instance]
87
+ A list of Instance objects with property `args` which returns a tuple (context,).
88
+ string: str
89
+ String for which we are computing overall loglikelihood
90
+ :return: list[tuple[float]]
91
+ A list of tuples (logprob,)
92
+ logprob: float
93
+ The log probability of `context` conditioned on the BOS/EOS token.
94
+ Can also be overridden for custom cases by `prefix_token_id`.
95
+ """
96
+ pass
97
+
98
+ # TODO: Add an optional max length
99
+ @abc.abstractmethod
100
+ def generate_until(self, requests) -> List[str]:
101
+ """Generate greedily until a stopping sequence
102
+
103
+ :param requests: list[Instance]
104
+ A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
105
+ context: str
106
+ Context string
107
+ gen_kwargs: dict
108
+ A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
109
+ :return: list[str]
110
+ A list of model generated continuations.
111
+ continuation: str
112
+ The generated continuation.
113
+ """
114
+ pass
115
+
116
+ def apply_chat_template(
117
+ self, chat_history: List[Dict[str, str]], add_generation_prompt=True
118
+ ) -> str:
119
+ """
120
+ Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
121
+
122
+ :param chat_history: list[dict[str, str]]
123
+ A list of dictionaries with keys 'role' and 'content'.
124
+ Values are strings representing the role name and the content of the message, respectively.
125
+ :param add_generation_prompt: bool
126
+ Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message.
127
+ :return: str
128
+ A string representing the chat history in a format that can be used as input to the LM.
129
+ """
130
+ raise NotImplementedError(
131
+ "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
132
+ )
133
+
134
+ @classmethod
135
+ def create_from_arg_string(
136
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
137
+ ) -> T:
138
+ """
139
+ Creates an instance of the LM class using the given argument string and additional config.
140
+
141
+ Parameters:
142
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
143
+ - additional_config: Optional dictionary containing additional configuration parameters.
144
+
145
+ Returns:
146
+ - Instance of the LM class.
147
+ """
148
+ additional_config = {} if additional_config is None else additional_config
149
+ args = utils.simple_parse_args_string(arg_string)
150
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
151
+ return cls(**args, **args2)
152
+
153
+ @classmethod
154
+ def create_from_arg_obj(
155
+ cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
156
+ ) -> T:
157
+ """
158
+ Creates an instance of the LM class using the given arg_obj
159
+
160
+ Parameters:
161
+ - arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
162
+ - additional_config: Optional dictionary containing additional configuration parameters.
163
+
164
+ Returns:
165
+ - Instance of the LM class.
166
+ """
167
+
168
+ additional_config = {} if additional_config is None else additional_config
169
+ additional_config = {
170
+ k: v for k, v in additional_config.items() if v is not None
171
+ }
172
+
173
+ return cls(**arg_dict, **additional_config)
174
+
175
+ @property
176
+ def rank(self):
177
+ # used in the case of parallelism. Hardcoded to
178
+ # ensure no errors arise using API models which do
179
+ # not support multi-device parallelism nor expect it.
180
+ return self._rank
181
+
182
+ @property
183
+ def world_size(self):
184
+ # used in the case of parallelism. Hardcoded to
185
+ # ensure no errors arise using API models which do
186
+ # not support multi-device parallelism nor expect it.
187
+ return self._world_size
188
+
189
+ @property
190
+ def tokenizer_name(self) -> str:
191
+ """Must be defined for LM subclasses which implement Chat Templating.
192
+ Should return the name of the tokenizer or chat template used.
193
+ Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
194
+ """
195
+ raise NotImplementedError(
196
+ "To use this model with chat templates, please implement the 'tokenizer_name' property."
197
+ )
198
+
199
+ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
200
+ """Returns the chat template structure for user/assistant messages if a template is provided.
201
+ This method is intended to be overridden in a subclass to define a specific chat template format.
202
+ For models that do not support chat templates, this method returns None by default.
203
+ """
204
+
205
+ return ""
206
+
207
+ def set_cache_hook(self, cache_hook) -> None:
208
+ self.cache_hook = cache_hook
209
+
210
+
211
+ ### SQLite-based caching of LM responses
212
+ def hash_args(attr, args):
213
+ dat = json.dumps([attr] + list(args))
214
+ return hashlib.sha256(dat.encode("utf-8")).hexdigest()
215
+
216
+
217
+ class CacheHook:
218
+ def __init__(self, cachinglm) -> None:
219
+ if cachinglm is None:
220
+ self.dbdict = None
221
+ return
222
+
223
+ self.dbdict = cachinglm.dbdict
224
+
225
+ def add_partial(self, attr, req, res) -> None:
226
+ if self.dbdict is None:
227
+ return
228
+ hsh = hash_args(attr, req)
229
+ self.dbdict[hsh] = res
230
+
231
+
232
+ class CachingLM:
233
+ def __init__(self, lm, cache_db) -> None:
234
+ """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
235
+
236
+ :param lm: LM
237
+ Underlying LM
238
+ :param cache_db: str
239
+ Path to cache db
240
+ """
241
+ self.lm = lm
242
+ self.cache_db = cache_db
243
+ if os.path.dirname(cache_db):
244
+ os.makedirs(os.path.dirname(cache_db), exist_ok=True)
245
+ self.dbdict = SqliteDict(cache_db, autocommit=True)
246
+
247
+ # add hook to lm
248
+ lm.set_cache_hook(self.get_cache_hook())
249
+
250
+ def __getattr__(self, attr: str):
251
+ lm_attr = getattr(self.lm, attr)
252
+ if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
253
+ eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
254
+ return lm_attr
255
+
256
+ def fn(requests):
257
+ res = []
258
+ remaining_reqs = []
259
+ warned = False
260
+ # figure out which ones are cached and which ones are new
261
+ eval_logger.info(
262
+ f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
263
+ )
264
+ for req in tqdm(requests, desc="Checking cached requests"):
265
+ hsh = hash_args(attr, req.args)
266
+ if attr == "generate_until" and req.args[1].get("do_sample", False):
267
+ # when we are doing non-greedy generation, don't use the cache
268
+ # (else every "randomly sampled" generation would be identical for repeats > 1).
269
+ if not warned:
270
+ eval_logger.warning(
271
+ f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
272
+ )
273
+ warned = True
274
+ res.append(None)
275
+ remaining_reqs.append(req)
276
+ elif hsh in self.dbdict:
277
+ ob = self.dbdict[hsh]
278
+
279
+ assert ob is not None
280
+
281
+ res.append(ob)
282
+ else:
283
+ res.append(None)
284
+ remaining_reqs.append(req)
285
+ eval_logger.info(
286
+ f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
287
+ )
288
+ if remaining_reqs:
289
+ # actually run the LM on the requests that do not have cached results
290
+ rem_res = getattr(self.lm, attr)(remaining_reqs)
291
+ else:
292
+ rem_res = []
293
+
294
+ # stick the new ones back into the list and also cache any of the new ones
295
+ resptr = 0
296
+ for req, r in zip(remaining_reqs, rem_res):
297
+ while res[resptr] is not None:
298
+ resptr += 1
299
+
300
+ res[resptr] = r
301
+
302
+ # caching
303
+ hsh = hash_args(attr, req.args)
304
+ self.dbdict[hsh] = r
305
+ self.dbdict.commit()
306
+
307
+ return res
308
+
309
+ return fn
310
+
311
+ def get_cache_hook(self):
312
+ return CacheHook(self)
313
+
314
+
315
+ class TemplateLM(LM):
316
+ """
317
+ A class acting as intermediary between the LM base class
318
+ and boilerplate often included in other LM subclasses.
319
+ """
320
+
321
+ tokenizer = None
322
+
323
+ @property
324
+ @abc.abstractmethod
325
+ def eot_token_id(self):
326
+ pass
327
+
328
+ @property
329
+ def prefix_token_id(self):
330
+ # it is used as prefix for loglikelihood
331
+ return self.eot_token_id
332
+
333
+ @abc.abstractmethod
334
+ def tok_encode(self, string: str, **kwargs) -> List[int]:
335
+ """
336
+ Tokenize a string using the model's tokenizer and return a list of token IDs.
337
+ """
338
+ pass
339
+
340
+ @abc.abstractmethod
341
+ def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
342
+ pass
343
+
344
+ def _encode_pair(
345
+ self, context: str, continuation: str
346
+ ) -> Tuple[List[int], List[int]]:
347
+ n_spaces = len(context) - len(context.rstrip())
348
+ if n_spaces > 0:
349
+ continuation = context[-n_spaces:] + continuation
350
+ context = context[:-n_spaces]
351
+
352
+ model_class = getattr(self, "AUTO_MODEL_CLASS", None)
353
+
354
+ if model_class == transformers.AutoModelForSeq2SeqLM:
355
+ context_enc = self.tok_encode(context)
356
+ continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
357
+ else:
358
+ whole_enc = self.tok_encode(context + continuation)
359
+ context_enc = self.tok_encode(context)
360
+
361
+ context_enc_len = len(context_enc)
362
+ continuation_enc = whole_enc[context_enc_len:]
363
+
364
+ return context_enc, continuation_enc
365
+
366
+ def loglikelihood(
367
+ self, requests, disable_tqdm: bool = False
368
+ ) -> List[Tuple[float, bool]]:
369
+ new_reqs = []
370
+ for context, continuation in [req.args for req in requests]:
371
+ if context == "":
372
+ # BOS or EOS as context
373
+ context_enc, continuation_enc = (
374
+ [self.prefix_token_id],
375
+ self.tok_encode(continuation),
376
+ )
377
+ else:
378
+ context_enc, continuation_enc = self._encode_pair(context, continuation)
379
+
380
+ new_reqs.append(((context, continuation), context_enc, continuation_enc))
381
+
382
+ return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
383
+
384
+ @abc.abstractmethod
385
+ def loglikelihood_rolling(
386
+ self, requests, disable_tqdm: bool = False
387
+ ) -> List[float]:
388
+ pass
389
+
390
+ @abc.abstractmethod
391
+ def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
392
+ pass
393
+
394
+ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
395
+ """
396
+ Set and get the appropriate chat template for the model.
397
+ This method sets the tokenizer's chat_template and returns the template string for reproducibility.
398
+
399
+ The template selection logic is adapted from the Transformers library's `apply_chat_template`
400
+ method in the Tokenizer class. The original implementation can be found at:
401
+ https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
402
+
403
+ This method ensures that the right template is chosen based on the following:
404
+ 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
405
+ 1. If the model's tokenizer has multiple templates:
406
+ a. Use the specified template if it exists in the dictionary.
407
+ b. Use the default template from the list if no specific template is provided.
408
+ c. Raise an error if no default template exists and no specific template is provided.
409
+ 2. If the model's tokenizer has a single template or no template:
410
+ a. Use the tokenizer's chat template if available.
411
+ b. Fall back to the default chat template if no tokenizer chat template exists.
412
+
413
+ Args:
414
+ chat_template (Union[bool, str]): Specifies the chat template to use.
415
+ - If False or None, no template is applied.
416
+ - If True, the default or only available template is used.
417
+ - If a string, the template with the matching name is used.
418
+
419
+ Returns:
420
+ Optional[str]: The selected chat template, or None if no template is applied.
421
+ """
422
+ if self.tokenizer is None:
423
+ return ""
424
+
425
+ if chat_template is False or chat_template is None:
426
+ eval_logger.warning(
427
+ "model.chat_template was called with the chat_template set to False or None. "
428
+ "Therefore no chat template will be applied. Make sure this is an intended behavior."
429
+ )
430
+ return None
431
+
432
+ # Convert boolean chat_template to None to ensure compatibility with the adapted logic
433
+ if isinstance(chat_template, bool):
434
+ chat_template = None
435
+ using_default_template = False
436
+
437
+ # First, handle the cases when the model has a dict of multiple templates
438
+ try:
439
+ template = (
440
+ self.tokenizer.chat_template or self.tokenizer.default_chat_template
441
+ )
442
+ except AttributeError:
443
+ return None
444
+
445
+ if isinstance(template, dict):
446
+ using_default_dict = self.tokenizer.chat_template is None
447
+
448
+ if chat_template is not None:
449
+ if chat_template in template:
450
+ selected_template = template[chat_template]
451
+ if using_default_dict:
452
+ using_default_template = True
453
+ else:
454
+ raise ValueError(
455
+ f"The specified chat template '{chat_template}' is not available. "
456
+ f"Available template names are {sorted(template.keys())}."
457
+ )
458
+ else:
459
+ # If user didn't pass a chat template, use the default template from the dict
460
+ if "default" in template:
461
+ selected_template = template["default"]
462
+ using_default_template = True
463
+ else:
464
+ raise ValueError(
465
+ "This model has multiple chat templates with no default specified! Please either pass a chat "
466
+ "template or the name of the template you wish to use to the `chat_template` argument. Available "
467
+ f"template names are {sorted(template.keys())}."
468
+ )
469
+
470
+ # Cases when the model has a single template or no template
471
+ else:
472
+ # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
473
+ if isinstance(chat_template, str):
474
+ eval_logger.warning(
475
+ "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
476
+ "Using the tokenizer's chat template or the default template instead."
477
+ )
478
+ if self.tokenizer.chat_template is not None:
479
+ selected_template = self.tokenizer.chat_template
480
+ else:
481
+ selected_template = self.tokenizer.default_chat_template
482
+ using_default_template = True
483
+
484
+ if using_default_template:
485
+ eval_logger.warning(
486
+ "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
487
+ "very error-prone, because models are often trained with templates different from the class default! "
488
+ "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
489
+ "point any code depending on them will stop working. We recommend setting a valid chat template before "
490
+ "then to ensure that this model continues working without issues."
491
+ )
492
+
493
+ return selected_template
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/registry.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Callable, Dict, Union
3
+
4
+ import evaluate as hf_evaluate
5
+
6
+ from lm_eval.api.model import LM
7
+
8
+
9
+ eval_logger = logging.getLogger(__name__)
10
+
11
+ MODEL_REGISTRY = {}
12
+
13
+
14
+ def register_model(*names):
15
+ # either pass a list or a single alias.
16
+ # function receives them as a tuple of strings
17
+
18
+ def decorate(cls):
19
+ for name in names:
20
+ assert issubclass(cls, LM), (
21
+ f"Model '{name}' ({cls.__name__}) must extend LM class"
22
+ )
23
+
24
+ assert name not in MODEL_REGISTRY, (
25
+ f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
26
+ )
27
+
28
+ MODEL_REGISTRY[name] = cls
29
+ return cls
30
+
31
+ return decorate
32
+
33
+
34
+ def get_model(model_name):
35
+ try:
36
+ return MODEL_REGISTRY[model_name]
37
+ except KeyError:
38
+ raise ValueError(
39
+ f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
40
+ )
41
+
42
+
43
+ TASK_REGISTRY = {}
44
+ GROUP_REGISTRY = {}
45
+ ALL_TASKS = set()
46
+ func2task_index = {}
47
+
48
+
49
+ def register_task(name):
50
+ def decorate(fn):
51
+ assert name not in TASK_REGISTRY, (
52
+ f"task named '{name}' conflicts with existing registered task!"
53
+ )
54
+
55
+ TASK_REGISTRY[name] = fn
56
+ ALL_TASKS.add(name)
57
+ func2task_index[fn.__name__] = name
58
+ return fn
59
+
60
+ return decorate
61
+
62
+
63
+ def register_group(name):
64
+ def decorate(fn):
65
+ func_name = func2task_index[fn.__name__]
66
+ if name in GROUP_REGISTRY:
67
+ GROUP_REGISTRY[name].append(func_name)
68
+ else:
69
+ GROUP_REGISTRY[name] = [func_name]
70
+ ALL_TASKS.add(name)
71
+ return fn
72
+
73
+ return decorate
74
+
75
+
76
+ OUTPUT_TYPE_REGISTRY = {}
77
+ METRIC_REGISTRY = {}
78
+ METRIC_AGGREGATION_REGISTRY = {}
79
+ AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
80
+ HIGHER_IS_BETTER_REGISTRY = {}
81
+ FILTER_REGISTRY = {}
82
+
83
+ DEFAULT_METRIC_REGISTRY = {
84
+ "loglikelihood": [
85
+ "perplexity",
86
+ "acc",
87
+ ],
88
+ "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
89
+ "multiple_choice": ["acc", "acc_norm"],
90
+ "generate_until": ["exact_match"],
91
+ }
92
+
93
+
94
+ def register_metric(**args):
95
+ # TODO: do we want to enforce a certain interface to registered metrics?
96
+ def decorate(fn):
97
+ assert "metric" in args
98
+ name = args["metric"]
99
+
100
+ for key, registry in [
101
+ ("metric", METRIC_REGISTRY),
102
+ ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
103
+ ("aggregation", METRIC_AGGREGATION_REGISTRY),
104
+ ]:
105
+ if key in args:
106
+ value = args[key]
107
+ assert value not in registry, (
108
+ f"{key} named '{value}' conflicts with existing registered {key}!"
109
+ )
110
+
111
+ if key == "metric":
112
+ registry[name] = fn
113
+ elif key == "aggregation":
114
+ registry[name] = AGGREGATION_REGISTRY[value]
115
+ else:
116
+ registry[name] = value
117
+
118
+ return fn
119
+
120
+ return decorate
121
+
122
+
123
+ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
124
+ if not hf_evaluate_metric:
125
+ if name in METRIC_REGISTRY:
126
+ return METRIC_REGISTRY[name]
127
+ else:
128
+ eval_logger.warning(
129
+ f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
130
+ )
131
+
132
+ try:
133
+ metric_object = hf_evaluate.load(name)
134
+ return metric_object.compute
135
+ except Exception:
136
+ eval_logger.error(
137
+ f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
138
+ )
139
+
140
+
141
+ def register_aggregation(name: str):
142
+ def decorate(fn):
143
+ assert name not in AGGREGATION_REGISTRY, (
144
+ f"aggregation named '{name}' conflicts with existing registered aggregation!"
145
+ )
146
+
147
+ AGGREGATION_REGISTRY[name] = fn
148
+ return fn
149
+
150
+ return decorate
151
+
152
+
153
+ def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
154
+ try:
155
+ return AGGREGATION_REGISTRY[name]
156
+ except KeyError:
157
+ eval_logger.warning(f"{name} not a registered aggregation metric!")
158
+
159
+
160
+ def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
161
+ try:
162
+ return METRIC_AGGREGATION_REGISTRY[name]
163
+ except KeyError:
164
+ eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
165
+
166
+
167
+ def is_higher_better(metric_name) -> bool:
168
+ try:
169
+ return HIGHER_IS_BETTER_REGISTRY[metric_name]
170
+ except KeyError:
171
+ eval_logger.warning(
172
+ f"higher_is_better not specified for metric '{metric_name}'!"
173
+ )
174
+
175
+
176
+ def register_filter(name):
177
+ def decorate(cls):
178
+ if name in FILTER_REGISTRY:
179
+ eval_logger.info(
180
+ f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
181
+ )
182
+ FILTER_REGISTRY[name] = cls
183
+ return cls
184
+
185
+ return decorate
186
+
187
+
188
+ def get_filter(filter_name: Union[str, Callable]) -> Callable:
189
+ try:
190
+ return FILTER_REGISTRY[filter_name]
191
+ except KeyError as e:
192
+ if callable(filter_name):
193
+ return filter_name
194
+ else:
195
+ eval_logger.warning(f"filter `{filter_name}` is not registered!")
196
+ raise e
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/samplers.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+ from functools import partial
4
+ from typing import TYPE_CHECKING, Iterable, Optional, Union
5
+
6
+ import datasets
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from random import Random
11
+
12
+ from lm_eval.api.task import ConfigurableTask, Task
13
+
14
+ eval_logger = logging.getLogger("lm-eval")
15
+
16
+
17
+ class ContextSampler:
18
+ def __init__(
19
+ self,
20
+ docs: list[dict],
21
+ task: Union["Task", "ConfigurableTask"],
22
+ fewshot_indices: Optional[Iterable] = None,
23
+ rnd: Optional["Random"] = None,
24
+ ) -> None:
25
+ self.rnd = rnd
26
+ if not self.rnd:
27
+ raise ValueError(
28
+ "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
29
+ )
30
+
31
+ self.task = task
32
+ self.config = task._config
33
+
34
+ self.target_delimiter = self.config.target_delimiter
35
+ self.fewshot_delimiter = self.config.fewshot_delimiter
36
+
37
+ if (
38
+ self.config.fewshot_config is not None
39
+ and self.config.fewshot_config.get("doc_to_text", None) is not None
40
+ ):
41
+ self.doc_to_text = partial(
42
+ self.task.doc_to_text,
43
+ doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
44
+ )
45
+ else:
46
+ self.doc_to_text = self.task.doc_to_text
47
+
48
+ if (
49
+ self.config.fewshot_config is not None
50
+ and self.config.fewshot_config.get("doc_to_target", None) is not None
51
+ ):
52
+ self.doc_to_target = partial(
53
+ self.task.doc_to_target,
54
+ doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
55
+ )
56
+ else:
57
+ self.doc_to_target = self.task.doc_to_target
58
+
59
+ if (
60
+ self.config.fewshot_config is not None
61
+ and self.config.fewshot_config.get("doc_to_choice", None) is not None
62
+ ):
63
+ self.doc_to_choice = partial(
64
+ self.task.doc_to_choice,
65
+ doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
66
+ )
67
+ else:
68
+ self.doc_to_choice = self.task.doc_to_choice
69
+
70
+ self.docs = docs # HF dataset split, provided by task._fewshot_docs()
71
+ if fewshot_indices: # subset few-shot docs from
72
+ if not isinstance(self.docs, datasets.Dataset):
73
+ raise ValueError(
74
+ "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
75
+ )
76
+ self.docs = self.docs.select(fewshot_indices)
77
+
78
+ def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
79
+ # draw an extra fewshot sample if using same split as evaluating on
80
+ prefix = gen_prefix + " " if gen_prefix else ""
81
+ n_samples = (
82
+ num_fewshot + 1
83
+ if self.config.fewshot_split == self.config.test_split
84
+ else num_fewshot
85
+ )
86
+
87
+ # draw `n_samples` docs from fewshot_docs
88
+ fewshotex = self.sample(n_samples)
89
+
90
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
91
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
92
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
93
+
94
+ labeled_examples = ""
95
+ for doc in selected_docs:
96
+ doc_content = self.doc_to_text(doc)
97
+ doc_target = self.doc_to_target(doc)
98
+ if self.config.doc_to_choice is None or isinstance(doc_content, str):
99
+ labeled_examples += doc_content
100
+ else:
101
+ labeled_examples += self.doc_to_choice(doc)[doc_content]
102
+
103
+ if doc_target != "":
104
+ if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
105
+ # TODO: add logger warn once here.
106
+ warnings.warn(
107
+ "Both target_delimiter and target start with a space. This may cause issues.",
108
+ Warning,
109
+ stacklevel=2,
110
+ )
111
+ labeled_examples += self.target_delimiter
112
+ labeled_examples += prefix
113
+ labeled_examples += (
114
+ str(doc_target[0])
115
+ if isinstance(doc_target, list)
116
+ else doc_target
117
+ if self.config.doc_to_choice is None or isinstance(doc_target, str)
118
+ else str(self.doc_to_choice(doc)[doc_target])
119
+ )
120
+ labeled_examples += self.fewshot_delimiter
121
+
122
+ return labeled_examples
123
+
124
+ def get_chat_context(
125
+ self,
126
+ doc: dict,
127
+ num_fewshot: int,
128
+ fewshot_as_multiturn: bool = False,
129
+ gen_prefix: Optional[str] = None,
130
+ ):
131
+ # TODO: Do we need any other delimiter
132
+ prefix = gen_prefix + " " if gen_prefix else ""
133
+ chat_history = []
134
+ # draw an extra fewshot sample if using same split as evaluating on
135
+ n_samples = (
136
+ num_fewshot + 1
137
+ if self.config.fewshot_split == self.config.test_split
138
+ else num_fewshot
139
+ )
140
+ # draw `n_samples` docs from fewshot_docs
141
+ fewshotex = self.sample(n_samples)
142
+
143
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
144
+ # TODO: should we just stop people from using fewshot from same split as evaluating?
145
+ selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
146
+
147
+ if fewshot_as_multiturn:
148
+ for doc in selected_docs:
149
+ doc_content = self.doc_to_text(doc)
150
+ doc_target = self.doc_to_target(doc)
151
+ chat_history.append(
152
+ {
153
+ "role": "user",
154
+ "content": doc_content
155
+ if self.config.doc_to_choice is None
156
+ or isinstance(doc_content, str)
157
+ else self.doc_to_choice(doc)[doc_content],
158
+ }
159
+ )
160
+ chat_history.append(
161
+ {
162
+ "role": "assistant",
163
+ "content": prefix + str(doc_target[0])
164
+ if isinstance(doc_target, list)
165
+ else prefix + doc_target
166
+ if self.config.doc_to_choice is None
167
+ or isinstance(doc_target, str)
168
+ else prefix + str(self.doc_to_choice(doc)[doc_target]),
169
+ }
170
+ )
171
+ else:
172
+ # get fewshot context as one user turn
173
+ chat_history.append(
174
+ {
175
+ "role": "user",
176
+ "content": self.get_context(
177
+ doc, num_fewshot, gen_prefix=gen_prefix
178
+ ),
179
+ }
180
+ )
181
+
182
+ return chat_history
183
+
184
+ def sample(self, n: int):
185
+ """
186
+ Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
187
+ """
188
+
189
+ return self.rnd.sample(self.docs, n)
190
+
191
+
192
+ class FirstNSampler(ContextSampler):
193
+ def sample(self, n: int) -> None:
194
+ """
195
+ Draw the first `n` samples in order from the specified split.
196
+ Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
197
+ """
198
+ assert n <= len(self.docs), (
199
+ f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
200
+ )
201
+ return self.docs[:n]
202
+
203
+
204
+ class BalancedSampler(ContextSampler):
205
+ def sample(self, n: int) -> None:
206
+ """
207
+ TODO: this should return approximately class-balanced samples from our fewshot examples.
208
+ TODO: what order should they be in? maybe random?
209
+ """
210
+
211
+ pass
212
+
213
+
214
+ class ManualSampler(ContextSampler):
215
+ def sample(self, n: int) -> None:
216
+ """ """
217
+ pass
218
+
219
+
220
+ SAMPLER_REGISTRY = {
221
+ "default": ContextSampler,
222
+ "first_n": FirstNSampler,
223
+ }
224
+
225
+
226
+ def get_sampler(name: str):
227
+ try:
228
+ return SAMPLER_REGISTRY[name]
229
+ except KeyError:
230
+ raise ValueError(
231
+ f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
232
+ )
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/api/task.py ADDED
@@ -0,0 +1,1839 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import ast
3
+ import logging
4
+ import random
5
+ import re
6
+ from collections.abc import Callable
7
+ from copy import deepcopy
8
+ from dataclasses import asdict, dataclass
9
+ from inspect import getsource
10
+ from typing import (
11
+ Any,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Mapping,
18
+ Optional,
19
+ Tuple,
20
+ Union,
21
+ )
22
+
23
+ import datasets
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+
27
+ from lm_eval import utils
28
+ from lm_eval.api import samplers
29
+ from lm_eval.api.instance import Instance, OutputType
30
+ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
31
+ from lm_eval.api.registry import (
32
+ AGGREGATION_REGISTRY,
33
+ DEFAULT_METRIC_REGISTRY,
34
+ get_aggregation,
35
+ get_metric,
36
+ get_metric_aggregation,
37
+ is_higher_better,
38
+ )
39
+ from lm_eval.caching.cache import load_from_cache, save_to_cache
40
+ from lm_eval.filters import build_filter_ensemble
41
+ from lm_eval.prompts import get_prompt
42
+
43
+
44
+ ALL_OUTPUT_TYPES = [
45
+ "loglikelihood",
46
+ "multiple_choice",
47
+ "loglikelihood_rolling",
48
+ "generate_until",
49
+ ]
50
+
51
+ eval_logger = logging.getLogger(__name__)
52
+
53
+
54
+ @dataclass
55
+ class TaskConfig(dict):
56
+ # task naming/registry
57
+ task: Optional[str] = None
58
+ task_alias: Optional[str] = None
59
+ tag: Optional[Union[str, list]] = None
60
+ # HF dataset options.
61
+ # which dataset to use,
62
+ # and what splits for what purpose
63
+ custom_dataset: Optional[Callable] = None
64
+ dataset_path: Optional[str] = None
65
+ dataset_name: Optional[str] = None
66
+ dataset_kwargs: Optional[dict] = None
67
+ training_split: Optional[str] = None
68
+ validation_split: Optional[str] = None
69
+ test_split: Optional[str] = None
70
+ fewshot_split: Optional[str] = (
71
+ None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
72
+ )
73
+ # formatting / prompting options.
74
+ # see docs/advanced_task_guide.md for more info
75
+ process_docs: Optional[Callable] = None
76
+ doc_to_text: Optional[Union[Callable, str]] = None
77
+ doc_to_target: Optional[Union[Callable, str]] = None
78
+ doc_to_image: Union[Callable, str] = None
79
+ doc_to_audio: Union[Callable, str] = None
80
+ unsafe_code: bool = False
81
+ doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
82
+ process_results: Optional[Union[Callable, str]] = None
83
+ use_prompt: Optional[str] = None
84
+ description: str = ""
85
+ target_delimiter: str = " "
86
+ fewshot_delimiter: str = "\n\n"
87
+ fewshot_config: Optional[dict] = None
88
+ # runtime configuration options
89
+ num_fewshot: Optional[int] = None
90
+ # scoring options
91
+ metric_list: Optional[list] = None
92
+ output_type: OutputType = "generate_until"
93
+ generation_kwargs: Optional[dict] = None
94
+ repeats: int = 1
95
+ filter_list: Optional[Union[str, list]] = None
96
+ should_decontaminate: bool = False
97
+ doc_to_decontamination_query: Optional[str] = None
98
+ gen_prefix: Optional[str] = None
99
+ metadata: Optional[dict] = (
100
+ None # by default, not used in the code. allows for users to pass arbitrary info to tasks
101
+ )
102
+
103
+ def __post_init__(self) -> None:
104
+ if self.generation_kwargs is not None:
105
+ if self.output_type != "generate_until":
106
+ eval_logger.warning(
107
+ f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
108
+ )
109
+
110
+ if "temperature" in self.generation_kwargs:
111
+ self.generation_kwargs["temperature"] = float(
112
+ self.generation_kwargs["temperature"]
113
+ )
114
+
115
+ if "until" not in self.generation_kwargs:
116
+ self.generation_kwargs["until"] = [self.fewshot_delimiter]
117
+ else:
118
+ if self.output_type == "generate_until":
119
+ # ensure that we greedily generate in absence of explicit arguments otherwise
120
+ self.generation_kwargs = {
121
+ "until": (
122
+ None
123
+ if self.fewshot_delimiter is None
124
+ else [self.fewshot_delimiter]
125
+ ),
126
+ "do_sample": False,
127
+ }
128
+
129
+ def __getitem__(self, item):
130
+ return getattr(self, item)
131
+
132
+ def __setitem__(self, item, value):
133
+ return setattr(self, item, value)
134
+
135
+ def to_dict(self, keep_callable: bool = False) -> dict:
136
+ """dumps the current config as a dictionary object, as a printable format.
137
+ null fields will not be printed.
138
+ Used for dumping results alongside full task configuration
139
+
140
+ :return: dict
141
+ A printable dictionary version of the TaskConfig object.
142
+
143
+ # TODO: should any default value in the TaskConfig not be printed?
144
+ """
145
+ cfg_dict = asdict(self)
146
+ # remove values that are `None`
147
+ for k, v in list(cfg_dict.items()):
148
+ if v is None:
149
+ cfg_dict.pop(k)
150
+ elif k == "metric_list":
151
+ for metric_dict in v:
152
+ for metric_key, metric_value in metric_dict.items():
153
+ if callable(metric_value):
154
+ metric_dict[metric_key] = self.serialize_function(
155
+ metric_value, keep_callable=keep_callable
156
+ )
157
+ cfg_dict[k] = v
158
+ elif callable(v):
159
+ cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
160
+ return cfg_dict
161
+
162
+ def serialize_function(
163
+ self, value: Union[Callable, str], keep_callable=False
164
+ ) -> Union[Callable, str]:
165
+ """Serializes a given function or string.
166
+
167
+ If 'keep_callable' is True, the original callable is returned.
168
+ Otherwise, attempts to return the source code of the callable using 'getsource'.
169
+ """
170
+ if keep_callable:
171
+ return value
172
+ else:
173
+ try:
174
+ return getsource(value)
175
+ except (TypeError, OSError):
176
+ return str(value)
177
+
178
+
179
+ class Task(abc.ABC):
180
+ """A task represents an entire benchmark including its dataset, problems,
181
+ answers, and evaluation methods. See BoolQ for a simple example implementation
182
+
183
+ A `doc` can be any python object which represents one instance of evaluation.
184
+ This is usually a dictionary e.g.
185
+ {"question": ..., "answer": ...} or
186
+ {"question": ..., question, answer)
187
+ """
188
+
189
+ VERSION: Optional[Union[int, str]] = None
190
+
191
+ # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
192
+ # or a path to a custom `datasets` loading script.
193
+ DATASET_PATH: Optional[str] = None
194
+
195
+ # The name of a subset within `DATASET_PATH`.
196
+ DATASET_NAME: Optional[str] = None
197
+
198
+ OUTPUT_TYPE: Optional[OutputType] = None
199
+
200
+ def __init__(
201
+ self,
202
+ data_dir: Optional[str] = None,
203
+ cache_dir: Optional[str] = None,
204
+ download_mode: Optional[datasets.DownloadMode] = None,
205
+ config: Optional[Mapping] = None, # Union[dict, TaskConfig]
206
+ ) -> None:
207
+ """
208
+ :param data_dir: str
209
+ Stores the path to a local folder containing the `Task`'s data files.
210
+ Use this to specify the path to manually downloaded data (usually when
211
+ the dataset is not publicly accessible).
212
+ :param cache_dir: str
213
+ The directory to read/write the `Task` dataset. This follows the
214
+ HuggingFace `datasets` API with the default cache directory located at:
215
+ `~/.cache/huggingface/datasets`
216
+ NOTE: You can change the cache location globally for a given process
217
+ to another directory:
218
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
219
+ :param download_mode: datasets.DownloadMode
220
+ How to treat pre-existing `Task` downloads and data.
221
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
222
+ Reuse download and reuse dataset.
223
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
224
+ Reuse download with fresh dataset.
225
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
226
+ Fresh download and fresh dataset.
227
+ """
228
+ self.download(data_dir, cache_dir, download_mode)
229
+ self._training_docs: Optional[list] = None
230
+ self._fewshot_docs: Optional[list] = None
231
+ self._instances: Optional[List[Instance]] = None
232
+
233
+ self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
234
+
235
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
236
+ self.fewshot_rnd: Optional[random.Random] = (
237
+ None # purposely induce errors in case of improper usage
238
+ )
239
+
240
+ def download(
241
+ self,
242
+ data_dir: Optional[str] = None,
243
+ cache_dir: Optional[str] = None,
244
+ download_mode=None,
245
+ ) -> None:
246
+ """Downloads and returns the task dataset.
247
+ Override this method to download the dataset from a custom API.
248
+
249
+ :param data_dir: str
250
+ Stores the path to a local folder containing the `Task`'s data files.
251
+ Use this to specify the path to manually downloaded data (usually when
252
+ the dataset is not publicly accessible).
253
+ :param cache_dir: str
254
+ The directory to read/write the `Task` dataset. This follows the
255
+ HuggingFace `datasets` API with the default cache directory located at:
256
+ `~/.cache/huggingface/datasets`
257
+ NOTE: You can change the cache location globally for a given process
258
+ by setting the shell environment variable, `HF_DATASETS_CACHE`,
259
+ to another directory:
260
+ `export HF_DATASETS_CACHE="/path/to/another/directory"`
261
+ :param download_mode: datasets.DownloadMode
262
+ How to treat pre-existing `Task` downloads and data.
263
+ - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
264
+ Reuse download and reuse dataset.
265
+ - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
266
+ Reuse download with fresh dataset.
267
+ - `datasets.DownloadMode.FORCE_REDOWNLOAD`
268
+ Fresh download and fresh dataset.
269
+ """
270
+ self.dataset = datasets.load_dataset(
271
+ path=self.DATASET_PATH,
272
+ name=self.DATASET_NAME,
273
+ data_dir=data_dir,
274
+ cache_dir=cache_dir,
275
+ download_mode=download_mode,
276
+ )
277
+
278
+ @property
279
+ def config(self) -> TaskConfig:
280
+ """Returns the TaskConfig associated with this class."""
281
+ return self._config
282
+
283
+ @abc.abstractmethod
284
+ def has_training_docs(self):
285
+ """Whether the task has a training set"""
286
+ pass
287
+
288
+ @abc.abstractmethod
289
+ def has_validation_docs(self):
290
+ """Whether the task has a validation set"""
291
+ pass
292
+
293
+ @abc.abstractmethod
294
+ def has_test_docs(self):
295
+ """Whether the task has a test set"""
296
+ pass
297
+
298
+ def training_docs(self) -> Iterable:
299
+ """
300
+ :return: Iterable[obj]
301
+ A iterable of any object, that doc_to_text can handle
302
+ """
303
+ return []
304
+
305
+ def validation_docs(self) -> Iterable:
306
+ """
307
+ :return: Iterable[obj]
308
+ A iterable of any object, that doc_to_text can handle
309
+ """
310
+ return []
311
+
312
+ def test_docs(self) -> Iterable:
313
+ """
314
+ :return: Iterable[obj]
315
+ A iterable of any object, that doc_to_text can handle
316
+ """
317
+ return []
318
+
319
+ def fewshot_docs(self) -> Iterable:
320
+ """
321
+ :return: Iterable[obj]
322
+ A iterable of any object, that doc_to_text can handle
323
+ """
324
+ if self.has_training_docs():
325
+ return self.training_docs()
326
+ elif self.has_validation_docs():
327
+ return self.validation_docs()
328
+ else:
329
+ if self.config.get("num_fewshot", 0) > 0:
330
+ eval_logger.warning(
331
+ f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
332
+ ", using test_docs as fewshot_docs but this is not recommended."
333
+ )
334
+ return self.test_docs()
335
+
336
+ def _process_doc(self, doc: dict) -> dict:
337
+ """
338
+ Override this to process (detokenize, strip, replace, etc.) individual
339
+ documents. This can be used in a map over documents of a data split.
340
+ E.g. `map(self._process_doc, self.dataset["validation"])`
341
+
342
+ :return: dict
343
+ The processed version of the specified `doc`.
344
+ """
345
+ return doc
346
+
347
+ @property
348
+ def instances(self) -> List[Instance]:
349
+ """After calling `task.build_all_requests()`, tasks
350
+ maintain a list of the dataset instances which will be evaluated.
351
+ """
352
+ return self._instances
353
+
354
+ def fewshot_examples(self, k, rnd):
355
+ if self._training_docs is None:
356
+ self._training_docs = list(self.training_docs())
357
+
358
+ return rnd.sample(self._training_docs, k)
359
+
360
+ def doc_to_decontamination_query(self, doc):
361
+ raise NotImplementedError(
362
+ "Override doc_to_decontamination_query with document specific decontamination query."
363
+ )
364
+
365
+ @abc.abstractmethod
366
+ def doc_to_text(self, doc):
367
+ pass
368
+
369
+ @abc.abstractmethod
370
+ def doc_to_target(self, doc):
371
+ pass
372
+
373
+ # not an abstractmethod because not every language-only task has to implement this
374
+ def doc_to_image(self, doc):
375
+ raise NotImplementedError
376
+
377
+ def doc_to_audio(self, doc):
378
+ raise NotImplementedError
379
+
380
+ def doc_to_prefix(self, doc):
381
+ return ""
382
+
383
+ def build_all_requests(
384
+ self,
385
+ *,
386
+ limit: Union[int, None] = None,
387
+ rank: int = 0,
388
+ world_size: int = 1,
389
+ cache_requests: bool = False,
390
+ rewrite_requests_cache: bool = False,
391
+ system_instruction: Optional[str] = None,
392
+ apply_chat_template: bool = False,
393
+ fewshot_as_multiturn: bool = False,
394
+ chat_template: Optional[Callable] = None,
395
+ tokenizer_name: str = "",
396
+ ) -> None:
397
+ """Build a set of Instances for a task, and store them in task.instances"""
398
+
399
+ # used with caching
400
+ og_limit = limit
401
+
402
+ cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
403
+ cache_key += "-chat_template" if apply_chat_template else ""
404
+ cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
405
+ cache_key += (
406
+ f"-system_prompt_hash{utils.hash_string(system_instruction)}"
407
+ if system_instruction is not None
408
+ else ""
409
+ )
410
+ cache_key += f"-tokenizer{tokenizer_name}"
411
+
412
+ cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
413
+
414
+ if cache_requests and cached_instances and not rewrite_requests_cache:
415
+ cached_instances = cached_instances[:limit]
416
+
417
+ flattened_instances = [
418
+ instance
419
+ for instance_group in cached_instances
420
+ for instance in instance_group
421
+ ]
422
+
423
+ self._instances = flattened_instances
424
+ return
425
+
426
+ eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
427
+
428
+ instances = []
429
+
430
+ # process all documents when caching is specified for simplicity
431
+ if (
432
+ cache_requests
433
+ and (not cached_instances or rewrite_requests_cache)
434
+ and limit is not None
435
+ ):
436
+ limit = None
437
+
438
+ doc_id_docs = list(
439
+ self.doc_iterator(rank=rank, limit=limit, world_size=world_size)
440
+ )
441
+
442
+ num_docs = len(doc_id_docs)
443
+
444
+ for doc_id, doc in tqdm(
445
+ doc_id_docs,
446
+ total=num_docs,
447
+ ):
448
+ # sample fewshot context #TODO: need to offset doc_id by rank now!
449
+ fewshot_ctx = self.fewshot_context(
450
+ doc,
451
+ 0 if self.config.num_fewshot is None else self.config.num_fewshot,
452
+ system_instruction,
453
+ apply_chat_template,
454
+ fewshot_as_multiturn,
455
+ chat_template,
456
+ gen_prefix=self.doc_to_prefix(doc),
457
+ )
458
+
459
+ # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
460
+ inst = self.construct_requests(
461
+ doc=doc,
462
+ ctx=fewshot_ctx,
463
+ metadata=(self.config["task"], doc_id, self.config.repeats),
464
+ apply_chat_template=apply_chat_template,
465
+ chat_template=chat_template,
466
+ )
467
+
468
+ if not isinstance(inst, list):
469
+ inst = [inst]
470
+
471
+ instances.append(inst)
472
+
473
+ # now flatten, this is to allow slicing to work with pickles
474
+
475
+ sliced_instances = instances[:og_limit]
476
+
477
+ flattened_instances = [
478
+ instance
479
+ for instance_group in sliced_instances
480
+ for instance in instance_group
481
+ ]
482
+
483
+ self._instances = flattened_instances
484
+
485
+ if len(self._instances) == 0:
486
+ raise ValueError("task.build_requests() did not find any docs!")
487
+
488
+ if cache_requests and (not cached_instances or rewrite_requests_cache):
489
+ save_to_cache(file_name=cache_key, obj=instances)
490
+
491
+ @abc.abstractmethod
492
+ def construct_requests(self, doc, ctx, **kwargs):
493
+ """Uses RequestFactory to construct Requests and returns an iterable of
494
+ Requests which will be sent to the LM.
495
+
496
+ :param doc:
497
+ The document as returned from training_docs, validation_docs, or test_docs.
498
+ :param ctx: str
499
+ The context string, generated by fewshot_context. This includes the natural
500
+ language description, as well as the few shot examples, and the question
501
+ part of the document for `doc`.
502
+ :param doc_idx: int
503
+ The index of a document within `self.test_docs()` or `self.validation_docs()`,
504
+ whichever is the main split used.
505
+ :param repeats: int
506
+ TODO: update this docstring
507
+ The number of times each instance in a dataset is inferred on. Defaults to 1,
508
+ can be increased for techniques like majority voting.
509
+ """
510
+ pass
511
+
512
+ @abc.abstractmethod
513
+ def process_results(self, doc, results):
514
+ """Take a single document and the LM results and evaluates, returning a
515
+ dict where keys are the names of submetrics and values are the values of
516
+ the metric for that one document
517
+
518
+ :param doc:
519
+ The document as returned from training_docs, validation_docs, or test_docs.
520
+ :param results:
521
+ The results of the requests created in construct_requests.
522
+ """
523
+ pass
524
+
525
+ @abc.abstractmethod
526
+ def aggregation(self):
527
+ """
528
+ :returns: {str: [metric_score] -> float}
529
+ A dictionary where keys are the names of submetrics and values are
530
+ functions that aggregate a list of metric scores
531
+ """
532
+ pass
533
+
534
+ @abc.abstractmethod
535
+ def higher_is_better(self):
536
+ """
537
+ :returns: {str: bool}
538
+ A dictionary where keys are the names of submetrics and values are
539
+ whether a higher value of the submetric is better
540
+ """
541
+ pass
542
+
543
+ def get_config(self, key: str) -> Any:
544
+ return getattr(self._config, key, None)
545
+
546
+ @classmethod
547
+ def count_bytes(cls, doc):
548
+ """Used for byte-level perplexity metrics in rolling loglikelihood"""
549
+ return len(doc.encode("utf-8"))
550
+
551
+ @classmethod
552
+ def count_words(cls, doc):
553
+ """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
554
+ return len(re.split(r"\s+", doc))
555
+
556
+ @utils.positional_deprecated
557
+ def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
558
+ """Returns a fewshot context string that is made up of a prepended description
559
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
560
+
561
+ :param doc: str
562
+ The document as returned from training_docs, validation_docs, or test_docs.
563
+ :param num_fewshot: int
564
+ The number of fewshot examples to provide in the returned context string.
565
+ :param rnd: random.Random
566
+ The pseudo-random number generator used to randomly sample examples.
567
+ WARNING: This is currently a required arg although it's optionalized with a default `None`.
568
+ :param description: str
569
+ The task's description that will be prepended to the fewshot examples.
570
+ :returns: str
571
+ The fewshot context.
572
+ """
573
+ if rnd is None:
574
+ if self.fewshot_rnd is not None:
575
+ rnd = self.fewshot_rnd
576
+ else:
577
+ raise ValueError(
578
+ "A `random.Random` generator argument must be provided to `rnd`"
579
+ )
580
+
581
+ description = description if description else ""
582
+
583
+ if num_fewshot == 0:
584
+ labeled_examples = ""
585
+ else:
586
+ # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
587
+ if self.has_training_docs():
588
+ fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
589
+ else:
590
+ if self._fewshot_docs is None:
591
+ self._fewshot_docs = list(
592
+ self.validation_docs()
593
+ if self.has_validation_docs()
594
+ else self.test_docs()
595
+ )
596
+
597
+ fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
598
+
599
+ # get rid of the doc that's the one we're evaluating, if it's in the fewshot
600
+ fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
601
+
602
+ labeled_examples = (
603
+ "\n\n".join(
604
+ [
605
+ self.doc_to_text(doc) + self.doc_to_target(doc)
606
+ for doc in fewshotex
607
+ ]
608
+ )
609
+ + "\n\n"
610
+ )
611
+
612
+ example = self.doc_to_text(doc)
613
+ return description + labeled_examples + example
614
+
615
+ def apply_filters(self) -> Optional[List[Instance]]:
616
+ """Iterates over FilterEnsembles and applies them to instances"""
617
+ if hasattr(self, "_filters"):
618
+ for f in self._filters:
619
+ f.apply(self._instances)
620
+ else:
621
+ eval_logger.warning("No filter defined, passing through instances")
622
+ return self._instances
623
+
624
+ def dump_config(self) -> dict:
625
+ """Returns the config as a dictionary."""
626
+ # TODO: this should only return the overrides applied to a non-YAML task's configuration.
627
+ # (num_fewshot)
628
+ return self.config.to_dict()
629
+
630
+ def set_config(self, key: str, value: Any, update: bool = False) -> None:
631
+ """Set or update the configuration for a given key."""
632
+ if key is None:
633
+ raise ValueError("Key must be provided.")
634
+
635
+ if update:
636
+ current_value = getattr(self._config, key, {})
637
+ if not isinstance(current_value, dict):
638
+ raise TypeError(
639
+ f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
640
+ )
641
+ current_value.update(value)
642
+ else:
643
+ setattr(self._config, key, value)
644
+
645
+ def override_metric(self, metric_name: str) -> None:
646
+ """
647
+ Override the default metrics used for evaluation with custom metrics.
648
+
649
+ Parameters:
650
+ - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
651
+ """
652
+ (
653
+ self._metric_fn_list,
654
+ self._aggregation_list,
655
+ self._metric_fn_kwargs,
656
+ self._higher_is_better,
657
+ ) = ({}, {}, {}, {})
658
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
659
+ self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
660
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
661
+ self._metric_fn_kwargs[metric_name] = {}
662
+ if not isinstance(self, ConfigurableTask):
663
+ self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
664
+ self.aggregation = lambda: {
665
+ metric_name: get_metric_aggregation(metric_name)
666
+ }
667
+ setattr(self._config, "metric_list", [{"metric": metric_name}])
668
+ setattr(self._config, "process_results", None)
669
+
670
+ def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
671
+ self.fewshot_rnd = random.Random(seed)
672
+ if hasattr(self, "sampler"):
673
+ self.sampler.rnd = self.fewshot_rnd
674
+
675
+ @property
676
+ def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
677
+ if self.has_test_docs():
678
+ return self.test_docs()
679
+ elif self.has_validation_docs():
680
+ return self.validation_docs()
681
+ else:
682
+ raise ValueError(
683
+ f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
684
+ )
685
+
686
+ def doc_iterator(
687
+ self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
688
+ ) -> Iterator[Tuple[int, Any]]:
689
+ limit = int(limit) if limit else None
690
+ doc_iterator = utils.create_iterator(
691
+ enumerate(self.eval_docs),
692
+ rank=int(rank),
693
+ limit=limit,
694
+ world_size=int(world_size),
695
+ )
696
+ return doc_iterator
697
+
698
+
699
+ class ConfigurableTask(Task):
700
+ VERSION = "Yaml"
701
+ OUTPUT_TYPE = None
702
+ CONFIG = None
703
+
704
+ def __init__(
705
+ self,
706
+ data_dir=None,
707
+ cache_dir=None,
708
+ download_mode=None,
709
+ config: Optional[dict] = None,
710
+ ) -> None: # TODO no super() call here
711
+ # Get pre-configured attributes
712
+ self._config = self.CONFIG
713
+
714
+ # Use new configurations if there was no preconfiguration
715
+ if self.config is None:
716
+ self._config = TaskConfig(**config)
717
+ # Overwrite configs
718
+ else:
719
+ if config is not None:
720
+ self._config.__dict__.update(config)
721
+
722
+ if self.config is None:
723
+ raise ValueError(
724
+ "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
725
+ )
726
+
727
+ if isinstance(self.config.metadata, dict):
728
+ if "version" in self.config.metadata:
729
+ self.VERSION = self.config.metadata["version"]
730
+
731
+ if self.config.output_type is not None:
732
+ if self.config.output_type not in ALL_OUTPUT_TYPES:
733
+ raise ValueError(
734
+ f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
735
+ )
736
+ self.OUTPUT_TYPE = self.config.output_type
737
+
738
+ if self.config.doc_to_image is not None:
739
+ # mark the task as requiring multimodality.
740
+ self.MULTIMODAL = True
741
+
742
+ if self.config.doc_to_audio:
743
+ # mark the task as requiring multimodality.
744
+ self.MULTIMODAL = True
745
+
746
+ if self.config.unsafe_code is not False:
747
+ self.UNSAFE_CODE = True
748
+
749
+ if self.config.dataset_path is not None:
750
+ self.DATASET_PATH = self.config.dataset_path
751
+
752
+ if self.config.dataset_name is not None:
753
+ self.DATASET_NAME = self.config.dataset_name
754
+
755
+ self._metric_fn_list = {}
756
+ self._metric_fn_kwargs = {}
757
+ self._aggregation_list = {}
758
+ self._higher_is_better = {}
759
+
760
+ if self.config.metric_list is None:
761
+ # TODO: handle this in TaskConfig.__post_init__ ?
762
+ _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
763
+
764
+ for metric_name in _metric_list:
765
+ self._metric_fn_list[metric_name] = get_metric(metric_name)
766
+ self._metric_fn_kwargs[metric_name] = {}
767
+ self._aggregation_list[metric_name] = get_metric_aggregation(
768
+ metric_name
769
+ )
770
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
771
+ else:
772
+ for metric_config in self.config.metric_list:
773
+ if "metric" not in metric_config:
774
+ raise ValueError(
775
+ "'metric' key not provided for an entry in 'metric_list', must be specified!"
776
+ )
777
+ metric_name = metric_config["metric"]
778
+ kwargs = {
779
+ key: metric_config[key]
780
+ for key in metric_config
781
+ if key
782
+ not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
783
+ }
784
+ hf_evaluate_metric = (
785
+ "hf_evaluate" in metric_config
786
+ and metric_config["hf_evaluate"] is True
787
+ )
788
+
789
+ if self.config.process_results is not None:
790
+ self._metric_fn_list[metric_name] = None
791
+ self._metric_fn_kwargs[metric_name] = {}
792
+ elif callable(metric_name):
793
+ metric_fn = metric_name.__call__
794
+ metric_name = metric_name.__name__
795
+ self._metric_fn_list[metric_name] = metric_fn
796
+ self._metric_fn_kwargs[metric_name] = kwargs
797
+ else:
798
+ self._metric_fn_list[metric_name] = get_metric(
799
+ metric_name, hf_evaluate_metric
800
+ )
801
+ self._metric_fn_kwargs[metric_name] = kwargs
802
+
803
+ if "aggregation" in metric_config:
804
+ agg_name = metric_config["aggregation"]
805
+ if isinstance(agg_name, str):
806
+ self._aggregation_list[metric_name] = get_aggregation(agg_name)
807
+ elif callable(agg_name): # noqa: E721
808
+ self._aggregation_list[metric_name] = metric_config[
809
+ "aggregation"
810
+ ]
811
+ else:
812
+ INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
813
+ metric_agg = get_metric_aggregation(metric_name)
814
+ eval_logger.warning(
815
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
816
+ f"using default "
817
+ f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
818
+ )
819
+ self._aggregation_list[metric_name] = metric_agg
820
+
821
+ if "higher_is_better" in metric_config:
822
+ self._higher_is_better[metric_name] = metric_config[
823
+ "higher_is_better"
824
+ ]
825
+ else:
826
+ eval_logger.warning(
827
+ f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
828
+ f"using default "
829
+ f"higher_is_better={is_higher_better(metric_name)}"
830
+ )
831
+ self._higher_is_better[metric_name] = is_higher_better(metric_name)
832
+
833
+ self.download(self.config.dataset_kwargs)
834
+ self._training_docs = None
835
+ self._fewshot_docs = None
836
+
837
+ if self.config.filter_list is not None:
838
+ self._filters = []
839
+ for filter_config in self.config.filter_list:
840
+ filter_name = filter_config["name"]
841
+ filter_functions = filter_config["filter"]
842
+ components = []
843
+ for function in filter_functions:
844
+ kwargs = {
845
+ key: function[key] for key in function if key != "function"
846
+ }
847
+ components.append([function["function"], kwargs])
848
+ filter_pipeline = build_filter_ensemble(filter_name, components)
849
+ self._filters.append(filter_pipeline)
850
+ else:
851
+ # TODO: handle repeats in a more general way rather than just discarding
852
+ eval_logger.debug(
853
+ "No custom filters defined. Using default 'take_first' filter for handling repeats."
854
+ )
855
+ self._filters = [build_filter_ensemble("none", [["take_first", None]])]
856
+
857
+ if self.config.use_prompt is not None:
858
+ eval_logger.info(f"loading prompt {self.config.use_prompt}")
859
+ self.prompt = get_prompt(
860
+ self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
861
+ )
862
+ else:
863
+ self.prompt = None
864
+
865
+ if self.fewshot_docs() is not None:
866
+ self.fewshot_rnd = (
867
+ random.Random()
868
+ ) # setting with no seed, to be overridden at a later time
869
+ config_sampler: Union[str, Callable] = (
870
+ self.config.fewshot_config.get("sampler", "default")
871
+ if self.config.fewshot_config
872
+ else "default"
873
+ )
874
+ if isinstance(config_sampler, str):
875
+ self.sampler = samplers.get_sampler(config_sampler)(
876
+ list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
877
+ )
878
+ elif callable(config_sampler) and issubclass(
879
+ config_sampler, samplers.ContextSampler
880
+ ):
881
+ self.sampler = config_sampler(
882
+ docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
883
+ )
884
+ else:
885
+ raise TypeError(
886
+ f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
887
+ f"not {type(config_sampler)}"
888
+ )
889
+
890
+ self.task_docs = self.eval_docs
891
+
892
+ # Test One Doc
893
+ self.features = list(self.task_docs.features.keys())
894
+ self.multiple_input = 0
895
+ self.multiple_target = 0
896
+ test_doc = self.task_docs[0]
897
+ test_text = self.doc_to_text(test_doc)
898
+ test_target = self.doc_to_target(test_doc)
899
+
900
+ if self.config.doc_to_choice is not None:
901
+ test_choice = self.doc_to_choice(test_doc)
902
+ if not isinstance(test_choice, list):
903
+ eval_logger.error("doc_to_choice must return list")
904
+ else:
905
+ num_choice = len(test_choice)
906
+
907
+ if isinstance(test_text, int):
908
+ self.multiple_input = num_choice
909
+ else:
910
+ test_choice = None
911
+
912
+ if isinstance(test_target, list):
913
+ self.multiple_target = len(test_target)
914
+ else:
915
+ if (isinstance(test_target, int)) and (test_choice is not None):
916
+ test_target = test_choice[test_target]
917
+ else:
918
+ test_target = str(test_target)
919
+
920
+ if test_choice is not None:
921
+ check_choices = test_choice
922
+ else:
923
+ check_choices = [test_target]
924
+ if self.config.doc_to_choice is not None:
925
+ for choice in check_choices:
926
+ choice_has_whitespace = True if choice[0].isspace() else False
927
+ delimiter_has_whitespace = (
928
+ True
929
+ if self.config.target_delimiter.rstrip()
930
+ != self.config.target_delimiter
931
+ else False
932
+ )
933
+
934
+ if delimiter_has_whitespace and choice_has_whitespace:
935
+ eval_logger.debug(
936
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
937
+ )
938
+ elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
939
+ eval_logger.debug(
940
+ f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
941
+ )
942
+
943
+ def download(
944
+ self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
945
+ ) -> None:
946
+ if isinstance(self.config.custom_dataset, Callable):
947
+ eval_logger.warning(
948
+ f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
949
+ + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
950
+ )
951
+ self.dataset = self.config.custom_dataset(
952
+ **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
953
+ )
954
+ else:
955
+ self.dataset = datasets.load_dataset(
956
+ path=self.DATASET_PATH,
957
+ name=self.DATASET_NAME,
958
+ **dataset_kwargs if dataset_kwargs is not None else {},
959
+ )
960
+
961
+ def has_training_docs(self) -> bool:
962
+ if self.config.training_split is not None:
963
+ return True
964
+ else:
965
+ return False
966
+
967
+ def has_validation_docs(self) -> bool:
968
+ if self.config.validation_split is not None:
969
+ return True
970
+ else:
971
+ return False
972
+
973
+ def has_test_docs(self) -> bool:
974
+ if self.config.test_split is not None:
975
+ return True
976
+ else:
977
+ return False
978
+
979
+ def training_docs(self) -> datasets.Dataset:
980
+ if self.has_training_docs():
981
+ if self.config.process_docs is not None:
982
+ return self.config.process_docs(
983
+ self.dataset[self.config.training_split]
984
+ )
985
+ return self.dataset[self.config.training_split]
986
+
987
+ def validation_docs(self) -> datasets.Dataset:
988
+ if self.has_validation_docs():
989
+ if self.config.process_docs is not None:
990
+ return self.config.process_docs(
991
+ self.dataset[self.config.validation_split]
992
+ )
993
+ return self.dataset[self.config.validation_split]
994
+
995
+ def test_docs(self) -> datasets.Dataset:
996
+ if self.has_test_docs():
997
+ if self.config.process_docs is not None:
998
+ return self.config.process_docs(self.dataset[self.config.test_split])
999
+ return self.dataset[self.config.test_split]
1000
+
1001
+ def fewshot_docs(self):
1002
+ if self.config.fewshot_split is not None:
1003
+ if self.config.process_docs is not None:
1004
+ return self.config.process_docs(self.dataset[self.config.fewshot_split])
1005
+ return self.dataset[self.config.fewshot_split]
1006
+ elif (
1007
+ self.config.fewshot_config is not None
1008
+ and self.config.fewshot_config.get("samples", None) is not None
1009
+ ):
1010
+ if isinstance(self.config.fewshot_config["samples"], list):
1011
+ return self.config.fewshot_config["samples"]
1012
+ elif callable(self.config.fewshot_config["samples"]):
1013
+ return self.config.fewshot_config["samples"]()
1014
+ else:
1015
+ raise Exception(
1016
+ "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
1017
+ )
1018
+ else:
1019
+ if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
1020
+ eval_logger.warning(
1021
+ f"[Task: {self.config.task}] "
1022
+ "num_fewshot > 0 but fewshot_split is None. "
1023
+ "using preconfigured rule."
1024
+ )
1025
+ return super().fewshot_docs()
1026
+
1027
+ @staticmethod
1028
+ def append_target_question(
1029
+ labeled_examples: List[Dict[str, str]],
1030
+ question: str,
1031
+ fewshot_as_multiturn: bool = False,
1032
+ gen_prefix: Optional[str] = None,
1033
+ ) -> None:
1034
+ """Adds a target question to the labeled examples list.
1035
+ If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
1036
+ Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
1037
+ """
1038
+ if not fewshot_as_multiturn:
1039
+ # if no messages or last message is system, append as new user entry
1040
+ if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
1041
+ labeled_examples.append({"role": "user", "content": question})
1042
+ # if last message is user, append to it to avoid two user messages in a row
1043
+ else:
1044
+ labeled_examples[-1]["content"] += question
1045
+ else:
1046
+ # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
1047
+ labeled_examples.append({"role": "user", "content": question})
1048
+ if gen_prefix:
1049
+ labeled_examples.append({"role": "assistant", "content": gen_prefix})
1050
+
1051
+ @utils.positional_deprecated
1052
+ def fewshot_context(
1053
+ self,
1054
+ doc: dict,
1055
+ num_fewshot: int,
1056
+ system_instruction: Optional[str] = None,
1057
+ apply_chat_template: bool = False,
1058
+ fewshot_as_multiturn: bool = False,
1059
+ chat_template: Optional[Callable] = None,
1060
+ gen_prefix: Optional[str] = None,
1061
+ ) -> Union[str, List[str]]:
1062
+ """Returns a fewshot context string that is made up of a prepended description
1063
+ (if provided), the `num_fewshot` number of examples, and an appended prompt example.
1064
+
1065
+ :param doc: str
1066
+ The document as returned from training_docs, validation_docs, or test_docs.
1067
+ :param num_fewshot: int
1068
+ The number of fewshot examples to provide in the returned context string.
1069
+ :param system_instruction: str
1070
+ System instruction to be applied to the prompt.
1071
+ :param apply_chat_template: bool
1072
+ Whether to apply the chat template to the fewshot context.
1073
+ :param fewshot_as_multiturn: bool
1074
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
1075
+ :param chat_template:
1076
+ callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
1077
+ :param gen_prefix:
1078
+ String to append after the <|assistant|> token.
1079
+ :returns: str
1080
+ The fewshot context.
1081
+ """
1082
+ if apply_chat_template:
1083
+ labeled_examples = []
1084
+ else:
1085
+ labeled_examples = ""
1086
+
1087
+ # get task description
1088
+ if description := self.config.description:
1089
+ description = utils.apply_template(self.config.description, doc)
1090
+
1091
+ # create system prompt based on the provided system instruction and description
1092
+ if system_instruction is not None and description:
1093
+ system_prompt = (
1094
+ f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
1095
+ )
1096
+ elif system_instruction is not None:
1097
+ system_prompt = system_instruction
1098
+ elif description:
1099
+ system_prompt = description
1100
+ else:
1101
+ system_prompt = ""
1102
+
1103
+ # add system prompt if specified
1104
+ if system_prompt:
1105
+ if apply_chat_template:
1106
+ labeled_examples.append({"role": "system", "content": system_prompt})
1107
+ else:
1108
+ labeled_examples = system_prompt
1109
+ # if few-shot - append examples after the system prompt
1110
+ if num_fewshot > 0:
1111
+ if apply_chat_template:
1112
+ labeled_examples.extend(
1113
+ self.sampler.get_chat_context(
1114
+ doc,
1115
+ num_fewshot,
1116
+ fewshot_as_multiturn,
1117
+ gen_prefix=gen_prefix,
1118
+ )
1119
+ )
1120
+ else:
1121
+ labeled_examples += self.sampler.get_context(
1122
+ doc, num_fewshot, gen_prefix=gen_prefix
1123
+ )
1124
+
1125
+ example = self.doc_to_text(doc)
1126
+ if apply_chat_template:
1127
+ if self.multiple_input:
1128
+ # TODO: append prefill?
1129
+ if not labeled_examples:
1130
+ return ""
1131
+ return chat_template(labeled_examples)
1132
+ if isinstance(example, str):
1133
+ self.append_target_question(
1134
+ labeled_examples,
1135
+ example,
1136
+ fewshot_as_multiturn,
1137
+ gen_prefix=gen_prefix,
1138
+ )
1139
+ # for loglikelihood create a list of questions with appended choices
1140
+ elif isinstance(example, list):
1141
+ labeled_examples_list = []
1142
+ # copy chat history for each example and append the answer
1143
+ for ex in example:
1144
+ chat = deepcopy(labeled_examples)
1145
+ self.append_target_question(
1146
+ chat,
1147
+ ex,
1148
+ fewshot_as_multiturn,
1149
+ gen_prefix=gen_prefix,
1150
+ )
1151
+ # TODO: append prefill?
1152
+ labeled_examples_list.append(
1153
+ chat_template(
1154
+ chat,
1155
+ add_generation_prompt=False if gen_prefix else True,
1156
+ )
1157
+ )
1158
+ return labeled_examples_list
1159
+ # if example is an integer, append the choice or convert to string
1160
+ elif isinstance(example, int):
1161
+ if self.config.doc_to_choice is not None:
1162
+ choices = self.doc_to_choice(doc)
1163
+ self.append_target_question(
1164
+ labeled_examples,
1165
+ choices[example],
1166
+ fewshot_as_multiturn,
1167
+ gen_prefix=gen_prefix,
1168
+ )
1169
+ else:
1170
+ self.append_target_question(
1171
+ labeled_examples,
1172
+ str(example),
1173
+ fewshot_as_multiturn,
1174
+ gen_prefix=gen_prefix,
1175
+ )
1176
+ # return lm.apply_chat_template(labeled_examples)
1177
+ return chat_template(
1178
+ labeled_examples,
1179
+ add_generation_prompt=False if gen_prefix else True,
1180
+ )
1181
+ else:
1182
+ prefix = (
1183
+ self.config.target_delimiter + gen_prefix
1184
+ if gen_prefix is not None
1185
+ else ""
1186
+ )
1187
+ if self.multiple_input:
1188
+ return labeled_examples
1189
+ if isinstance(example, str):
1190
+ return labeled_examples + example + prefix
1191
+ elif isinstance(example, list):
1192
+ return [labeled_examples + ex + prefix for ex in example]
1193
+ elif isinstance(example, int):
1194
+ if self.config.doc_to_choice is not None:
1195
+ choices = self.doc_to_choice(doc)
1196
+ return labeled_examples + choices[example] + prefix
1197
+ else:
1198
+ return labeled_examples + str(example) + prefix
1199
+
1200
+ def apply_filters(self) -> Optional[List[Instance]]:
1201
+ """Iterates over FilterEnsembles and applies them to instances"""
1202
+ if hasattr(self, "_filters"):
1203
+ for f in self._filters:
1204
+ f.apply(self._instances)
1205
+ else:
1206
+ eval_logger.warning("No filter defined, passing through instances")
1207
+ return self._instances
1208
+
1209
+ def should_decontaminate(self):
1210
+ return self.config.should_decontaminate
1211
+
1212
+ def doc_to_decontamination_query(self, doc: dict):
1213
+ if self.config.should_decontaminate:
1214
+ if self.config.doc_to_decontamination_query is None:
1215
+ return self.doc_to_text(doc)
1216
+ else:
1217
+ doc_to_decontamination_query = self.config.doc_to_decontamination_query
1218
+ if doc_to_decontamination_query in self.features:
1219
+ return doc[doc_to_decontamination_query]
1220
+ elif callable(doc_to_decontamination_query):
1221
+ return doc_to_decontamination_query(doc)
1222
+ else:
1223
+ return ast.literal_eval(
1224
+ utils.apply_template(
1225
+ self.config.doc_to_decontamination_query, doc
1226
+ )
1227
+ )
1228
+
1229
+ def _process_doc(self, doc: dict) -> dict:
1230
+ """
1231
+ Override this to process (detokenize, strip, replace, etc.) individual
1232
+ documents. This can be used in a map over documents of a data split.
1233
+ E.g. `map(self._process_doc, self.dataset["validation"])`
1234
+
1235
+ :return: dict
1236
+ The processed version of the specified `doc`.
1237
+ """
1238
+ return doc
1239
+
1240
+ def doc_to_text(self, doc, doc_to_text=None):
1241
+ if self.prompt is not None:
1242
+ doc_to_text = self.prompt
1243
+ elif doc_to_text is not None:
1244
+ doc_to_text = doc_to_text
1245
+ else:
1246
+ doc_to_text = self.config.doc_to_text
1247
+
1248
+ if isinstance(doc_to_text, int):
1249
+ return doc_to_text
1250
+ elif isinstance(doc_to_text, str):
1251
+ if doc_to_text in self.features:
1252
+ # if self.config.doc_to_choice is not None:
1253
+ # return self.doc_to_choice(doc)[doc[doc_to_text]]
1254
+ # else:
1255
+ return doc[doc_to_text]
1256
+ else:
1257
+ text_string = utils.apply_template(doc_to_text, doc)
1258
+ if text_string.isdigit() and self._config.doc_to_choice is not None:
1259
+ return ast.literal_eval(text_string)
1260
+ else:
1261
+ return text_string
1262
+ elif callable(doc_to_text):
1263
+ return doc_to_text(doc)
1264
+ # Used when applying a Promptsource template
1265
+ elif hasattr(doc_to_text, "apply"):
1266
+ applied_prompt = doc_to_text.apply(doc)
1267
+ if len(applied_prompt) == 2:
1268
+ return applied_prompt[0]
1269
+ else:
1270
+ eval_logger.warning("Applied prompt returns empty string")
1271
+ return self.config.fewshot_delimiter
1272
+ else:
1273
+ print(type(doc_to_text))
1274
+ raise TypeError
1275
+
1276
+ def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
1277
+ if self.prompt is not None:
1278
+ doc_to_target = self.prompt
1279
+ elif doc_to_target is not None:
1280
+ doc_to_target = doc_to_target
1281
+ else:
1282
+ doc_to_target = self.config.doc_to_target
1283
+
1284
+ if isinstance(doc_to_target, int):
1285
+ return doc_to_target
1286
+ elif isinstance(doc_to_target, str):
1287
+ if doc_to_target in self.features:
1288
+ # if self.config.doc_to_choice is not None:
1289
+ # return self.doc_to_choice(doc)[doc[doc_to_target]]
1290
+ # else:
1291
+ return doc[doc_to_target]
1292
+ else:
1293
+ target_string = utils.apply_template(doc_to_target, doc)
1294
+ if target_string.isdigit() and self._config.doc_to_choice is not None:
1295
+ return ast.literal_eval(target_string)
1296
+ elif (
1297
+ len(target_string) >= 2
1298
+ and (target_string[0] == "[")
1299
+ and (target_string[-1] == "]")
1300
+ ):
1301
+ try:
1302
+ return ast.literal_eval(target_string)
1303
+ except (SyntaxError, ValueError):
1304
+ return target_string
1305
+ else:
1306
+ return target_string
1307
+ elif isinstance(doc_to_target, list):
1308
+ return doc_to_target
1309
+ elif callable(doc_to_target):
1310
+ return doc_to_target(doc)
1311
+ # Used when applying a Promptsource template
1312
+ elif hasattr(doc_to_target, "apply"):
1313
+ applied_prompt = doc_to_target.apply(doc)
1314
+ if len(applied_prompt) == 2:
1315
+ return applied_prompt[1]
1316
+ else:
1317
+ eval_logger.warning("Applied prompt returns empty string")
1318
+ return self.config.fewshot_delimiter
1319
+ else:
1320
+ raise TypeError
1321
+
1322
+ def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
1323
+ if self.prompt is not None:
1324
+ doc_to_choice = self.prompt
1325
+ elif doc_to_choice is not None:
1326
+ doc_to_choice = doc_to_choice
1327
+ elif self.config.doc_to_choice is None:
1328
+ eval_logger.error("doc_to_choice was called but not set in config")
1329
+ else:
1330
+ doc_to_choice = self.config.doc_to_choice
1331
+
1332
+ if isinstance(doc_to_choice, str):
1333
+ if doc_to_choice in self.features:
1334
+ return doc[doc_to_choice]
1335
+ else:
1336
+ return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
1337
+ elif isinstance(doc_to_choice, list):
1338
+ return doc_to_choice
1339
+ elif isinstance(doc_to_choice, dict):
1340
+ return list(doc_to_choice.values())
1341
+ elif callable(doc_to_choice):
1342
+ return doc_to_choice(doc)
1343
+ elif hasattr(doc_to_choice, "get_answer_choices_list"):
1344
+ return doc_to_choice.get_answer_choices_list(doc)
1345
+ else:
1346
+ raise TypeError
1347
+
1348
+ def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
1349
+ if doc_to_image is not None:
1350
+ doc_to_image = doc_to_image
1351
+ elif self.config.doc_to_image is not None:
1352
+ doc_to_image = self.config.doc_to_image
1353
+ else:
1354
+ return None
1355
+
1356
+ if isinstance(doc_to_image, list):
1357
+ image_feature = [
1358
+ self.doc_to_image(doc, feature) for feature in doc_to_image
1359
+ ]
1360
+ return [feature for feature in image_feature if feature is not None]
1361
+ elif isinstance(doc_to_image, str):
1362
+ if doc_to_image in self.features:
1363
+ return doc[doc_to_image]
1364
+ else:
1365
+ return ast.literal_eval(utils.apply_template(doc_to_image, doc))
1366
+ elif callable(doc_to_image):
1367
+ return doc_to_image(doc)
1368
+ else:
1369
+ return None
1370
+
1371
+ def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
1372
+ if doc_to_audio is not None:
1373
+ doc_to_audio = doc_to_audio
1374
+ elif self.config.doc_to_audio is not None:
1375
+ doc_to_audio = self.config.doc_to_audio
1376
+ else:
1377
+ return None
1378
+
1379
+ if isinstance(doc_to_audio, list):
1380
+ audio_feature = [
1381
+ self.doc_to_audio(doc, feature) for feature in doc_to_audio
1382
+ ]
1383
+ return [feature for feature in audio_feature if feature is not None]
1384
+ elif isinstance(doc_to_audio, str):
1385
+ if doc_to_audio in self.features:
1386
+ return doc[doc_to_audio]
1387
+ else:
1388
+ return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
1389
+ elif callable(doc_to_audio):
1390
+ return doc_to_audio(doc)
1391
+ else:
1392
+ return None
1393
+
1394
+ def doc_to_prefix(self, doc):
1395
+ if (gen_prefix := self.config.gen_prefix) is not None:
1396
+ if gen_prefix in self.features:
1397
+ return doc[gen_prefix]
1398
+ else:
1399
+ return utils.apply_template(gen_prefix, doc)
1400
+ return None
1401
+
1402
+ def construct_requests(
1403
+ self, doc: dict, ctx: str, **kwargs
1404
+ ) -> Union[List[Instance], Instance]:
1405
+ apply_chat_template = kwargs.pop("apply_chat_template", False)
1406
+ chat_template: Callable | None = kwargs.pop("chat_template", None)
1407
+
1408
+ aux_arguments = None
1409
+
1410
+ if self.OUTPUT_TYPE == "loglikelihood":
1411
+ arguments = (ctx, self.doc_to_target(doc))
1412
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1413
+ arguments = (self.doc_to_target(doc),)
1414
+ elif self.OUTPUT_TYPE == "multiple_choice":
1415
+ choices = self.doc_to_choice(doc)
1416
+ target_delimiter = self.config.target_delimiter
1417
+ if apply_chat_template:
1418
+ target_delimiter = ""
1419
+ if self.multiple_input:
1420
+ # If there are multiple inputs, choices are placed in the ctx
1421
+ # apply chat_template to choices if apply_chat_template
1422
+ cont = self.doc_to_target(doc)
1423
+
1424
+ arguments = [
1425
+ (
1426
+ ctx
1427
+ + (
1428
+ chat_template([{"role": "user", "content": choice}])
1429
+ if apply_chat_template
1430
+ else choice
1431
+ ),
1432
+ f"{target_delimiter}{cont}",
1433
+ )
1434
+ for choice in choices
1435
+ ]
1436
+ else:
1437
+ # Otherwise they are placed in the continuation
1438
+ arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
1439
+
1440
+ # TODO: we should raise a warning telling users this will at most ~2x runtime.
1441
+ if "acc_mutual_info" in self._metric_fn_list.keys():
1442
+ # if we are calculating multiple choice accuracy
1443
+ # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
1444
+
1445
+ # here mutual info refers to calculating
1446
+ # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
1447
+ # in other words normalizing by subtracting the unconditional logprob of each choice.
1448
+ aux_arguments = [("", f"{choice}") for choice in choices]
1449
+
1450
+ arguments.extend(aux_arguments)
1451
+
1452
+ elif self.OUTPUT_TYPE == "generate_until":
1453
+ arguments = (ctx, deepcopy(self.config.generation_kwargs))
1454
+
1455
+ multimodal_arg = {}
1456
+ if (
1457
+ self.config.doc_to_image
1458
+ ): # TODO: ensure that non-multimodal tasks aren't getting visual args
1459
+ multimodal_arg = {
1460
+ **multimodal_arg,
1461
+ **{"visual": self.doc_to_image(doc)},
1462
+ }
1463
+
1464
+ if (
1465
+ self.config.doc_to_audio
1466
+ ): # TODO: ensure that non-multimodal tasks aren't getting audio args
1467
+ multimodal_arg = {
1468
+ **multimodal_arg,
1469
+ **{"audio": self.doc_to_audio(doc)},
1470
+ }
1471
+
1472
+ if bool(multimodal_arg):
1473
+ if isinstance(arguments, list):
1474
+ arguments = [arg + (multimodal_arg,) for arg in arguments]
1475
+ else:
1476
+ arguments = arguments + (multimodal_arg,)
1477
+
1478
+ if self.OUTPUT_TYPE == "multiple_choice":
1479
+ request_list = [
1480
+ Instance(
1481
+ request_type="loglikelihood",
1482
+ doc=doc,
1483
+ arguments=arg,
1484
+ idx=i,
1485
+ **kwargs,
1486
+ )
1487
+ for i, arg in enumerate(arguments)
1488
+ ]
1489
+
1490
+ return request_list
1491
+
1492
+ return Instance(
1493
+ request_type=self.OUTPUT_TYPE,
1494
+ doc=doc,
1495
+ arguments=arguments,
1496
+ idx=0,
1497
+ **kwargs,
1498
+ )
1499
+
1500
+ def process_results(self, doc, results):
1501
+ if callable(self.config.process_results):
1502
+ return self.config.process_results(doc, results)
1503
+
1504
+ result_dict = {}
1505
+ use_metric = list(self._metric_fn_list.keys())
1506
+ if self.OUTPUT_TYPE == "loglikelihood":
1507
+ results = results[0]
1508
+ ll, is_greedy = results
1509
+ return {
1510
+ **({"perplexity": ll} if "perplexity" in use_metric else {}),
1511
+ **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
1512
+ }
1513
+ elif self.OUTPUT_TYPE == "loglikelihood_rolling":
1514
+ (loglikelihood,) = results
1515
+ _words = self.count_words(self.doc_to_target(doc))
1516
+ _bytes = self.count_bytes(self.doc_to_target(doc))
1517
+ return {
1518
+ **(
1519
+ {"word_perplexity": (loglikelihood, _words)}
1520
+ if "word_perplexity" in use_metric
1521
+ else {}
1522
+ ),
1523
+ **(
1524
+ {"byte_perplexity": (loglikelihood, _bytes)}
1525
+ if "byte_perplexity" in use_metric
1526
+ else {}
1527
+ ),
1528
+ **(
1529
+ {"bits_per_byte": (loglikelihood, _bytes)}
1530
+ if "bits_per_byte" in use_metric
1531
+ else {}
1532
+ ),
1533
+ }
1534
+ elif self.OUTPUT_TYPE == "multiple_choice":
1535
+ lls, is_greedy = zip(*results)
1536
+
1537
+ # retrieve choices in List[str] form, to compute choice lengths, etc.
1538
+ choices = self.doc_to_choice(doc)
1539
+ completion_len = np.array([float(len(i)) for i in choices])
1540
+
1541
+ if (
1542
+ 2 * len(choices) == len(lls)
1543
+ and "acc_mutual_info" in self._metric_fn_list.keys()
1544
+ ):
1545
+ # then we are doing mutual info.
1546
+ # this stores the "dryrun" / unconditional answer loglikelihoods
1547
+ lls_unconditional = lls[1::2]
1548
+ if len(lls_unconditional) != len(choices):
1549
+ raise ValueError
1550
+ # and this stores our "regular" conditional loglikelihoods
1551
+ lls = lls[::2]
1552
+
1553
+ pred = np.argmax(lls)
1554
+ pred_norm = np.argmax(lls / completion_len)
1555
+
1556
+ if self.multiple_input:
1557
+ gold = self.doc_to_text(doc)
1558
+ else:
1559
+ gold = self.doc_to_target(doc)
1560
+
1561
+ gold_index_error = False
1562
+ if isinstance(gold, list):
1563
+ gold = [i if i < len(choices) else -100 for i in gold]
1564
+ if -100 in gold:
1565
+ gold_index_error = True
1566
+ else:
1567
+ if isinstance(gold, int):
1568
+ gold = gold if gold < len(choices) else -100
1569
+ elif isinstance(gold, str):
1570
+ gold = choices.index(gold) if gold in choices else -100
1571
+
1572
+ if gold == -100:
1573
+ gold_index_error = True
1574
+
1575
+ if gold_index_error:
1576
+ eval_logger.warning(
1577
+ f"Label index was not in within range of available choices,"
1578
+ f"Sample:\n\n{doc}\n\n"
1579
+ )
1580
+
1581
+ if self.multiple_target:
1582
+ acc = 1.0 if pred in gold else 0.0
1583
+ acc_norm = 1.0 if pred_norm in gold else 0.0
1584
+ exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
1585
+ else:
1586
+ acc = 1.0 if pred == gold else 0.0
1587
+ acc_norm = 1.0 if pred_norm == gold else 0.0
1588
+ # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
1589
+ exact_match = int(is_greedy[gold]) if gold != -100 else 0
1590
+
1591
+ prob_norm = utils.softmax(lls)
1592
+
1593
+ # TODO use keyword arguments to the metric?
1594
+ # gold, pred, norm stuff, the original lls,
1595
+ result_dict = {
1596
+ **({"acc": acc} if "acc" in use_metric else {}),
1597
+ **({"f1": (gold, pred)} if "f1" in use_metric else {}),
1598
+ **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
1599
+ **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
1600
+ **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
1601
+ **(
1602
+ {"brier_score": (gold, prob_norm)}
1603
+ if "brier_score" in use_metric
1604
+ else {}
1605
+ ),
1606
+ }
1607
+
1608
+ if "acc_mutual_info" in use_metric:
1609
+ lls_mutual_info = [
1610
+ ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
1611
+ ]
1612
+ acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
1613
+ result_dict["acc_mutual_info"] = acc_mutual_info
1614
+
1615
+ elif self.OUTPUT_TYPE == "generate_until":
1616
+ gold = self.doc_to_target(doc)
1617
+ result = results[0]
1618
+ if self.config.doc_to_choice is not None:
1619
+ # If you set doc_to_choice,
1620
+ # it assumes that doc_to_target returns a number.
1621
+ choices = self.doc_to_choice(doc)
1622
+ gold = choices[gold]
1623
+ # we expect multiple_targets to be a list.
1624
+ elif self.multiple_target:
1625
+ gold = list(gold)
1626
+ # TODO: handle this better
1627
+ elif type(gold) is not type(result) and not (
1628
+ "bypass" in self._metric_fn_list.keys() or isinstance(result, list)
1629
+ ):
1630
+ # cast gold to the same type as result
1631
+ gold = type(result)(gold)
1632
+
1633
+ for metric in self._metric_fn_list.keys():
1634
+ if self.multiple_target:
1635
+ # in the case where we have multiple targets,
1636
+ # return true if any are true
1637
+ # TODO: this may break for multipLe_target, non zero-or-1 metrics
1638
+ scores = []
1639
+ if not isinstance(gold, list):
1640
+ # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
1641
+ # print(gold)
1642
+ gold = [gold]
1643
+ if metric == "exact_match":
1644
+ result = [result for _ in range(len(gold))]
1645
+ scores = self._metric_fn_list[metric](
1646
+ references=gold,
1647
+ predictions=result,
1648
+ **self._metric_fn_kwargs[metric],
1649
+ )[metric]
1650
+ result_score = 1.0 if scores > 0.0 else 0.0
1651
+ else:
1652
+ for gold_option in gold:
1653
+ try:
1654
+ result_score = self._metric_fn_list[metric](
1655
+ references=[gold_option],
1656
+ predictions=[result],
1657
+ **self._metric_fn_kwargs[metric],
1658
+ )
1659
+ except (
1660
+ TypeError
1661
+ ): # TODO: this is hacky and I don't want to do it
1662
+ result_score = self._metric_fn_list[metric](
1663
+ [gold_option, result]
1664
+ )
1665
+ if isinstance(result_score, dict):
1666
+ # TODO: this handles the case where HF evaluate returns a dict.
1667
+ result_score = result_score[metric]
1668
+ scores.append(result_score)
1669
+ if any(scores):
1670
+ result_score = 1.0
1671
+ else:
1672
+ result_score = 0.0
1673
+ else:
1674
+ try:
1675
+ result_score = self._metric_fn_list[metric](
1676
+ references=[gold],
1677
+ predictions=[result],
1678
+ **self._metric_fn_kwargs[metric],
1679
+ )
1680
+ except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
1681
+ result_score = self._metric_fn_list[metric]([gold, result])
1682
+ if isinstance(result_score, dict):
1683
+ # TODO: this handles the case where HF evaluate returns a dict.
1684
+ # This allows for multiple metrics to be returned from the same function
1685
+ for k, v in result_score.items():
1686
+ result_dict[k] = v
1687
+ else:
1688
+ result_dict[metric] = result_score
1689
+ else:
1690
+ raise ValueError(
1691
+ f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
1692
+ "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
1693
+ )
1694
+
1695
+ return result_dict
1696
+
1697
+ def aggregation(self) -> dict:
1698
+ return self._aggregation_list
1699
+
1700
+ def higher_is_better(self) -> dict:
1701
+ return self._higher_is_better
1702
+
1703
+ def get_config(self, key: str) -> Any:
1704
+ return getattr(self._config, key, None)
1705
+
1706
+ @property
1707
+ def task_name(self) -> Any:
1708
+ return getattr(self.config, "task", None)
1709
+
1710
+ def __repr__(self):
1711
+ return (
1712
+ f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
1713
+ f"output_type={self.OUTPUT_TYPE},"
1714
+ f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
1715
+ f"num_samples={len(self.eval_docs)})"
1716
+ )
1717
+
1718
+
1719
+ class MultipleChoiceTask(Task):
1720
+ OUTPUT_TYPE = "loglikelihood"
1721
+
1722
+ def doc_to_target(self, doc: dict) -> str:
1723
+ return " " + doc["choices"][doc["gold"]]
1724
+
1725
+ def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
1726
+ # TODO: add mutual info here?
1727
+ return [
1728
+ Instance(
1729
+ request_type="loglikelihood",
1730
+ doc=doc,
1731
+ arguments=(ctx, " {}".format(choice)),
1732
+ idx=i,
1733
+ **kwargs,
1734
+ )
1735
+ for i, choice in enumerate(doc["choices"])
1736
+ ]
1737
+
1738
+ def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
1739
+ results = [
1740
+ res[0] for res in results
1741
+ ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
1742
+ gold = doc["gold"]
1743
+
1744
+ acc = 1.0 if np.argmax(results) == gold else 0.0
1745
+ completion_len = np.array([float(len(i)) for i in doc["choices"]])
1746
+ acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
1747
+
1748
+ return {
1749
+ "acc": acc,
1750
+ "acc_norm": acc_norm,
1751
+ }
1752
+
1753
+ def higher_is_better(self) -> dict:
1754
+ return {
1755
+ "acc": True,
1756
+ "acc_norm": True,
1757
+ }
1758
+
1759
+ def aggregation(self) -> dict:
1760
+ return {
1761
+ "acc": mean,
1762
+ "acc_norm": mean,
1763
+ }
1764
+
1765
+
1766
+ class PerplexityTask(Task):
1767
+ OUTPUT_TYPE = "loglikelihood_rolling"
1768
+
1769
+ def has_training_docs(self) -> bool:
1770
+ return False
1771
+
1772
+ def fewshot_examples(self, k: int, rnd) -> List:
1773
+ if k != 0:
1774
+ raise ValueError(
1775
+ "The number of fewshot examples must be 0 for perplexity tasks."
1776
+ )
1777
+ return []
1778
+
1779
+ def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
1780
+ if num_fewshot != 0:
1781
+ raise ValueError(
1782
+ "The number of fewshot examples must be 0 for perplexity tasks."
1783
+ )
1784
+
1785
+ return ""
1786
+
1787
+ def higher_is_better(self) -> dict:
1788
+ return {
1789
+ "word_perplexity": False,
1790
+ "byte_perplexity": False,
1791
+ "bits_per_byte": False,
1792
+ }
1793
+
1794
+ def doc_to_decontamination_query(self, doc):
1795
+ return doc
1796
+
1797
+ def doc_to_text(self, doc) -> str:
1798
+ return ""
1799
+
1800
+ def doc_to_target(self, doc):
1801
+ return doc
1802
+
1803
+ def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
1804
+ if bool(ctx):
1805
+ raise ValueError
1806
+
1807
+ return Instance(
1808
+ request_type=self.OUTPUT_TYPE,
1809
+ doc=doc,
1810
+ arguments=(self.doc_to_target(doc),),
1811
+ idx=0,
1812
+ **kwargs,
1813
+ )
1814
+
1815
+ def process_results(self, doc: dict, results: Tuple[float]) -> dict:
1816
+ (loglikelihood,) = results
1817
+ words = self.count_words(self.doc_to_target(doc))
1818
+ bytes_ = self.count_bytes(self.doc_to_target(doc))
1819
+ return {
1820
+ "word_perplexity": (loglikelihood, words),
1821
+ "byte_perplexity": (loglikelihood, bytes_),
1822
+ "bits_per_byte": (loglikelihood, bytes_),
1823
+ }
1824
+
1825
+ def aggregation(self) -> dict:
1826
+ return {
1827
+ "word_perplexity": weighted_perplexity,
1828
+ "byte_perplexity": weighted_perplexity,
1829
+ "bits_per_byte": bits_per_byte,
1830
+ }
1831
+
1832
+ @classmethod
1833
+ def count_bytes(cls, doc) -> int:
1834
+ return len(doc.encode("utf-8"))
1835
+
1836
+ @classmethod
1837
+ def count_words(cls, doc) -> int:
1838
+ """Downstream tasks with custom word boundaries should override this!"""
1839
+ return len(re.split(r"\s+", doc))
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/caching/__init__.py ADDED
File without changes
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/caching/cache.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import logging
3
+ import os
4
+
5
+ import dill
6
+
7
+
8
+ eval_logger = logging.getLogger(__name__)
9
+
10
+
11
+ MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
12
+
13
+ OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
14
+
15
+
16
+ PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
17
+
18
+ # This should be sufficient for uniqueness
19
+ HASH_INPUT = "EleutherAI-lm-evaluation-harness"
20
+
21
+ HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
22
+
23
+ FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
24
+
25
+
26
+ def load_from_cache(file_name: str, cache: bool = False):
27
+ if not cache:
28
+ return
29
+ try:
30
+ path = f"{PATH}/{file_name}{FILE_SUFFIX}"
31
+
32
+ with open(path, "rb") as file:
33
+ cached_task_dict = dill.loads(file.read())
34
+ return cached_task_dict
35
+
36
+ except Exception:
37
+ eval_logger.debug(f"{file_name} is not cached, generating...")
38
+ pass
39
+
40
+
41
+ def save_to_cache(file_name, obj):
42
+ if not os.path.exists(PATH):
43
+ os.mkdir(PATH)
44
+
45
+ file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
46
+
47
+ eval_logger.debug(f"Saving {file_path} to cache...")
48
+ with open(file_path, "wb") as file:
49
+ file.write(dill.dumps(obj))
50
+
51
+
52
+ # NOTE the "key" param is to allow for flexibility
53
+ def delete_cache(key: str = ""):
54
+ files = os.listdir(PATH)
55
+
56
+ for file in files:
57
+ if file.startswith(key) and file.endswith(FILE_SUFFIX):
58
+ file_path = f"{PATH}/{file}"
59
+ os.unlink(file_path)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/__init__.py ADDED
File without changes
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/archiver.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import io
3
+ import json
4
+ import mmap
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import jsonlines
10
+ import tqdm
11
+ import zstandard
12
+
13
+
14
+ def json_serial(obj: Any) -> str:
15
+ """JSON serializer for objects not serializable by default json code"""
16
+
17
+ if isinstance(obj, (datetime.datetime,)):
18
+ return obj.isoformat()
19
+ raise TypeError("Type %s not serializable" % type(obj))
20
+
21
+
22
+ # Modified version of lm_dataformat Archive for single file.
23
+ class Archive:
24
+ def __init__(self, file_path: str, compression_level: int = 3) -> None:
25
+ self.file_path = file_path
26
+ dir_name = os.path.dirname(file_path)
27
+ if dir_name:
28
+ os.makedirs(dir_name, exist_ok=True)
29
+ self.fh = open(self.file_path, "wb")
30
+ self.cctx = zstandard.ZstdCompressor(level=compression_level)
31
+ self.compressor = self.cctx.stream_writer(self.fh)
32
+
33
+ def add_data(self, data, meta=None) -> None:
34
+ if meta is None:
35
+ meta = {}
36
+ self.compressor.write(
37
+ json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
38
+ "UTF-8"
39
+ )
40
+ + b"\n"
41
+ )
42
+
43
+ def commit(self) -> None:
44
+ self.compressor.flush(zstandard.FLUSH_FRAME)
45
+ self.fh.flush()
46
+ self.fh.close()
47
+
48
+
49
+ # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
50
+ class Reader:
51
+ def __init__(self) -> None:
52
+ pass
53
+
54
+ def read(
55
+ self,
56
+ file,
57
+ get_meta: bool = False,
58
+ autojoin_paragraphs: bool = True,
59
+ para_joiner: str = "\n\n",
60
+ ):
61
+ with open(file, "rb") as fh:
62
+ self.fh = fh
63
+ cctx = zstandard.ZstdDecompressor()
64
+ reader = io.BufferedReader(cctx.stream_reader(fh))
65
+ rdr = jsonlines.Reader(reader)
66
+ for ob in rdr:
67
+ # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
68
+ if isinstance(ob, str):
69
+ assert not get_meta
70
+ yield ob
71
+ continue
72
+
73
+ text = ob["text"]
74
+
75
+ if autojoin_paragraphs and isinstance(text, list):
76
+ text = para_joiner.join(text)
77
+
78
+ if get_meta:
79
+ yield text, (ob["meta"] if "meta" in ob else {})
80
+ else:
81
+ yield text
82
+
83
+
84
+ class TextArchive:
85
+ def __init__(self, file_path, mode: str = "rb+") -> None:
86
+ self.file_path = file_path
87
+ dir_name = os.path.dirname(file_path)
88
+ if dir_name:
89
+ os.makedirs(dir_name, exist_ok=True)
90
+
91
+ if not os.path.exists(file_path):
92
+ Path(file_path).touch()
93
+
94
+ self.fh = open(self.file_path, mode)
95
+
96
+ def add_data(self, data) -> None:
97
+ self.fh.write(data.encode("UTF-8") + b"\n")
98
+
99
+ def commit(self) -> None:
100
+ self.fh.flush()
101
+ self.fh.close()
102
+
103
+
104
+ class TextReader:
105
+ def __init__(self, file_path) -> None:
106
+ self.file_path = file_path
107
+
108
+ # Optimized mmap read with infrequent tqdm updates to maintain speed
109
+ # Tested up to 250MB/s.
110
+ def read_tqdm(self, update_frequency: int = 10000):
111
+ current_file_position = 0
112
+ line_counter = 0
113
+ with (
114
+ open(self.file_path, "r", encoding="utf-8") as fh,
115
+ tqdm.tqdm(
116
+ total=os.path.getsize(self.file_path),
117
+ dynamic_ncols=True,
118
+ unit="byte",
119
+ unit_scale=1,
120
+ ) as progress,
121
+ ):
122
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
123
+ for line in iter(mmap_obj.readline, b""):
124
+ line = line.decode("utf-8")
125
+ line_counter += 1
126
+ if line_counter == update_frequency:
127
+ new_file_pos = mmap_obj.tell()
128
+ bytes_read = new_file_pos - current_file_position
129
+ current_file_position = new_file_pos
130
+ progress.update(bytes_read)
131
+ line_counter = 0
132
+ yield line[:-1]
133
+
134
+ def read_and_tell(self):
135
+ current_file_position = 0
136
+ with open(self.file_path, "r", encoding="utf8") as fh:
137
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
138
+ for line in iter(mmap_obj.readline, b""):
139
+ line = line.decode("utf-8")
140
+ new_file_pos = mmap_obj.tell()
141
+ raw_bytes_read = new_file_pos - current_file_position
142
+ current_file_position = new_file_pos
143
+ yield line[:-1], raw_bytes_read
144
+
145
+ def read(self):
146
+ with open(self.file_path, "r", encoding="utf8") as fh:
147
+ with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
148
+ for line in iter(mmap_obj.readline, b""):
149
+ line = line.decode("utf-8")
150
+ yield line[:-1]
151
+
152
+ def read_slow(self):
153
+ with open(self.file_path, "r", encoding="utf8") as fh:
154
+ while True:
155
+ line = fh.readline()
156
+ if line == -1 or line == "":
157
+ break
158
+ else:
159
+ yield line[:-1]
160
+
161
+
162
+ # Optimized for speed. Decompresses the archive in shell before
163
+ # using the mmap'd TextReader.
164
+ class ZStdTextReader:
165
+ def __init__(self, file) -> None:
166
+ self.file = file
167
+
168
+ def read_tqdm(self):
169
+ decompressed_file = self.file[:-4]
170
+ print("Decompressing file, please wait...")
171
+ os.system(f"zstd -d {self.file}") # linux decompress is faster
172
+ reader = TextReader(decompressed_file)
173
+ yield from reader.read_tqdm()
174
+ os.remove(decompressed_file)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/decontaminate.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import glob
3
+ import json
4
+ import os
5
+ import pickle
6
+ import random
7
+ import time
8
+
9
+ from .archiver import ZStdTextReader
10
+ from .janitor import Janitor, word_ngrams
11
+
12
+
13
+ # Was used for testing the evaluator decoupled from the full logic below
14
+ def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
15
+ simulated_overlap = 0.1
16
+ contaminated = int(len(docs) * simulated_overlap)
17
+ return random.sample(range(len(docs)), contaminated)
18
+
19
+
20
+ # Returns a dictionary containing all overlapping documents in each
21
+ # task. In the standard use case, an overlap occurs when any of the 13-grams
22
+ # found in the task document exist in the training set documents.
23
+ #
24
+ # To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
25
+ # scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
26
+ # files. These should exist in the "ngrams_path" provided to this function.
27
+
28
+
29
+ # Algorithm:
30
+ # 1. Build lookups for each dataset {ngram: list(document_ids)}
31
+ # 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
32
+ # 3. Full scan the 13-grams from the training set against the merged lookup,
33
+ # saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
34
+ # 4. Strip the task_set from the dictionary keys and return
35
+ #
36
+ # We cache the task+set lookups as well as the overlaps.
37
+ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
38
+ # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
39
+
40
+ info_dict_path = os.path.join(ngrams_path, "info.json")
41
+ info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
42
+ ngrams_n_size = info_dict["ngram_size"]
43
+
44
+ janitor = Janitor()
45
+
46
+ # Build lookup for each dataset first in case we use different task combinations later
47
+ print("Building Lookups...")
48
+ start = time.perf_counter()
49
+
50
+ def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
51
+ return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
52
+
53
+ lookups = {}
54
+ duplicates = {} # (task_name, task_set): set(doc_ids)}
55
+ sets_to_decontaminate = len(docs_by_task_set.keys())
56
+
57
+ for (task_name, task_set), docs in docs_by_task_set.items():
58
+ if not os.path.exists(f"data/{task_name}"):
59
+ os.mkdir(f"data/{task_name}")
60
+
61
+ # Check if we've decontaminated this combination before
62
+ overlaps_dump_path = get_overlaps_dump_path(
63
+ task_name, task_set, ngrams_n_size, limit
64
+ )
65
+ if os.path.exists(overlaps_dump_path):
66
+ duplicates[(task_name, task_set)] = pickle.load(
67
+ open(overlaps_dump_path, "rb")
68
+ )
69
+ sets_to_decontaminate -= 1
70
+ continue
71
+ else:
72
+ duplicates[(task_name, task_set)] = set()
73
+
74
+ # Build/load the task lookup {ngram: set(documents)}.
75
+ task_set_lookup_path = (
76
+ f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
77
+ )
78
+ if os.path.exists(task_set_lookup_path):
79
+ print(f"{task_set_lookup_path} available, loading...")
80
+ lookups[(task_name, task_set)] = pickle.load(
81
+ open(task_set_lookup_path, "rb")
82
+ )
83
+ else:
84
+ print(f"{task_set_lookup_path} not available, building...")
85
+ lookup = collections.defaultdict(set)
86
+
87
+ for doc_id, document in enumerate(docs):
88
+ ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
89
+ for ngram in ngrams:
90
+ lookup[ngram].add(doc_id)
91
+
92
+ pickle.dump(lookup, open(task_set_lookup_path, "wb"))
93
+ lookups[(task_name, task_set)] = lookup
94
+
95
+ elapsed = time.perf_counter() - start
96
+ print(f"Building lookups took {elapsed:0.5f} seconds.")
97
+
98
+ matched_ngrams = []
99
+
100
+ if sets_to_decontaminate > 0:
101
+ print("Merging lookups...")
102
+ start = time.perf_counter()
103
+ merged_lookup = collections.defaultdict(list)
104
+ for (task_name, task_set), lookup in lookups.items():
105
+ for ngram, doc_ids in lookup.items():
106
+ merged_lookup[ngram].append((task_name, task_set, doc_ids))
107
+
108
+ elapsed = time.perf_counter() - start
109
+ print(f"Merging lookups took {elapsed:0.5f} seconds.")
110
+
111
+ print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
112
+ files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
113
+ print(files)
114
+
115
+ for file in files:
116
+ start = time.perf_counter()
117
+ print(f"Scanning {file}")
118
+ reader = ZStdTextReader(file)
119
+ total_ngrams = 0
120
+ unique_ngrams = 0
121
+ matching_unique = 0
122
+ non_matching_unique = 0
123
+
124
+ current_ngram = ""
125
+ for line in reader.read_tqdm(): # Scan training set ngrams file
126
+ total_ngrams += 1
127
+ [ngram, document_id] = line.rsplit(" ", 1)
128
+ if (
129
+ ngram != current_ngram
130
+ ): # Only need to match the ngram once in training set
131
+ unique_ngrams += 1
132
+ current_ngram = ngram
133
+ if ngram in merged_lookup:
134
+ matched_ngrams.append(ngram) # For logging
135
+ matching_unique += 1
136
+ for task_name, task_set, doc_ids in merged_lookup[ngram]:
137
+ task_doc_set = duplicates[(task_name, task_set)]
138
+ for doc_id in doc_ids: # Record contamination across all relevant task/set combos
139
+ task_doc_set.add(doc_id)
140
+ del merged_lookup[ngram] # No point matching again
141
+ else:
142
+ non_matching_unique += 1
143
+
144
+ print(f"Total Ngrams: {total_ngrams}")
145
+ print(f"Unique Ngrams: {unique_ngrams}")
146
+ print(f"Unique Matching: {matching_unique}")
147
+ print(f"Unique Non Matching: {non_matching_unique}")
148
+ print("Matched ngrams:")
149
+ for ngram in matched_ngrams:
150
+ print(ngram)
151
+
152
+ elapsed = time.perf_counter() - start
153
+ print(f"Read took {elapsed:0.5f} seconds.")
154
+ print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
155
+
156
+ print(duplicates)
157
+
158
+ # Dump overlaps separately
159
+ for (task_name, task_set), doc_ids in duplicates.items():
160
+ overlaps_dump_path = get_overlaps_dump_path(
161
+ task_name, task_set, ngrams_n_size, limit
162
+ )
163
+ pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
164
+
165
+ # Strip task set and return
166
+ return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/decontamination/janitor.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import re
3
+ import string
4
+ import traceback
5
+ from typing import Iterator, List, Sequence, Tuple, TypeVar
6
+
7
+
8
+ # This is a cpp module. Compile janitor_util.cpp with:
9
+ # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
10
+ try:
11
+ import janitor_util
12
+
13
+ JANITOR_CPP = True
14
+ except Exception:
15
+ print("WARNING: C++ module could not be loaded. Janitor running in python mode")
16
+ traceback.print_exc()
17
+ JANITOR_CPP = False
18
+
19
+ T = TypeVar("T")
20
+
21
+
22
+ # Implementation from nltk source
23
+ # https://www.nltk.org/_modules/nltk/util.html
24
+ def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
25
+ history = []
26
+ while n > 1:
27
+ # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
28
+ try:
29
+ next_item = next(sequence)
30
+ except StopIteration:
31
+ # no more data, terminate the generator
32
+ return
33
+ history.append(next_item)
34
+ n -= 1
35
+ for item in sequence:
36
+ history.append(item)
37
+ yield tuple(history)
38
+ del history[0]
39
+
40
+
41
+ def word_ngrams(s: str, n: int) -> Iterator[str]:
42
+ """Splits a string into ngram words"""
43
+ tokens = s.split() # not a generator :(
44
+ ngram_seqs = form_ngrams(iter(tokens), n)
45
+ return (" ".join(ngram) for ngram in ngram_seqs)
46
+
47
+
48
+ # Does character sequences only - combined faster function to play around with later
49
+ # def word_ngrams_indices_combined(sequence, n):
50
+ # current_word = ""
51
+ # history = []
52
+ # gap = False;
53
+ # start = 0
54
+ # end = 0
55
+ # for character in sequence:
56
+ # if character == " ":
57
+ # if not gap:
58
+ # gap = True
59
+ # history.append(current_word)
60
+ # end += len(current_word) - 1
61
+ # current_word = ""
62
+ # if len(history) == n:
63
+ # yield (tuple(history), start, end)
64
+ # del history[0]
65
+ # start = end + 1
66
+ # end = start
67
+ # else:
68
+ # gap = False
69
+ # current_word += character
70
+
71
+
72
+ # https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
73
+ def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
74
+ """Splits a string on whitespaces and records the indices of each in the original string.
75
+ @:return generator((word, (start_idx, end_idx)), ...)
76
+ """
77
+ return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
78
+
79
+
80
+ def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
81
+ """Splits a string into pairs of (ngram words, their start/end indices)"""
82
+ tokens_with_indices = split_indices(s)
83
+
84
+ # Generator of ngrams of (word, idx_pairs)
85
+ # (
86
+ # [(word, (start,end)), (word, (start, end))...],
87
+ # [(word, (start, end)), ...],
88
+ # ...
89
+ # )
90
+ ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
91
+
92
+ # Generator of pairs of word and index ngrams
93
+ # (
94
+ # ([word, word, ...], [(start,end), (start,end), ...]),
95
+ # ...
96
+ # )
97
+ ngram_indices_pairs = (
98
+ zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
99
+ )
100
+
101
+ # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
102
+ return (
103
+ (" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
104
+ for ngram_seq, indices in ngram_indices_pairs
105
+ )
106
+
107
+
108
+ class Janitor:
109
+ # FIXME delete_chars: Should anything else go here? Special chars?
110
+ def __init__(
111
+ self,
112
+ ngram_n: int = 13,
113
+ window_to_remove: int = 200,
114
+ too_dirty_cutoff: int = 10,
115
+ minimum_slice_length: int = 200,
116
+ delete_chars: str = string.punctuation,
117
+ ) -> None:
118
+ self.ngram_n = ngram_n
119
+ self.window_to_remove = window_to_remove
120
+ self.too_dirty_cutoff = too_dirty_cutoff
121
+ self.minimum_slice_length = minimum_slice_length
122
+ self.delete_chars = delete_chars
123
+
124
+ self.dirt_ngrams = set()
125
+
126
+ # If in python, we'll translate uppercase to lowercase and delete naughty characters.
127
+ # This is fast by python standards
128
+ # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
129
+ self.translation_table = str.maketrans(
130
+ string.ascii_lowercase + string.ascii_uppercase, # These characters
131
+ string.ascii_lowercase * 2, # Become these characters
132
+ self.delete_chars, # These are deleted
133
+ )
134
+
135
+ ##############
136
+ # I/O for saving contamination ngrams
137
+ ##############
138
+
139
+ def save_contamination_ngrams(self, filename: str) -> None:
140
+ with open(filename, "wb") as fp:
141
+ pickle.dump(filename, fp)
142
+
143
+ def load_contamination_ngrams(self, filename: str) -> None:
144
+ with open(filename, "rb") as fp:
145
+ self.dirt_ngrams = pickle.load(fp)
146
+
147
+ ##############
148
+ # Call these :)
149
+ ##############
150
+
151
+ def register_contaminant(self, dirt_string: str) -> None:
152
+ """Register a string as contamination to be removed, e.g. a test set
153
+ This breaks the dirt_string into ngrams to store for future cleaning"""
154
+ if JANITOR_CPP:
155
+ return self.register_contaminant_cpp(dirt_string)
156
+ else:
157
+ print("WARNING: Janitor running in python mode")
158
+ return self.register_contaminant_python(dirt_string)
159
+
160
+ def clean(self, dirty_string: str) -> List[str]:
161
+ """Clean a string (e.g. a training set) by removing all ngrams previously
162
+ registered as contaminants. Returns a list of clean chunks, or empty if
163
+ the string was too dirty"""
164
+ if JANITOR_CPP:
165
+ return self.clean_cpp(dirty_string)
166
+ else:
167
+ print("WARNING: Janitor running in python mode")
168
+ return self.clean_python(dirty_string)
169
+
170
+ def _split_chunks(
171
+ self, dirty_string: str, dirty_parts: Sequence[Tuple]
172
+ ) -> List[str]:
173
+ clean_chunks = []
174
+ splice_idx = 0
175
+ end = -1
176
+ for i, (ngram, start, end) in enumerate(dirty_parts):
177
+ if i >= self.too_dirty_cutoff:
178
+ return []
179
+ start = max(0, start - self.window_to_remove)
180
+ end = min(len(dirty_string), end + self.window_to_remove)
181
+
182
+ if start - splice_idx > self.minimum_slice_length:
183
+ clean_chunks.append(dirty_string[splice_idx:start])
184
+ splice_idx = end
185
+
186
+ if end < len(dirty_string) - self.minimum_slice_length:
187
+ clean_chunks.append(dirty_string[end + 1 :])
188
+
189
+ return clean_chunks
190
+
191
+ ##############
192
+ # Fast C++
193
+ ##############
194
+
195
+ def register_contaminant_cpp(self, dirt_string) -> None:
196
+ self.dirt_ngrams.update(
197
+ janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
198
+ )
199
+
200
+ def clean_cpp(self, dirty_string: str) -> List[str]:
201
+ contamination_indices = janitor_util.clean_ngram_with_indices(
202
+ dirty_string, self.delete_chars, self.ngram_n
203
+ )
204
+ return self._split_chunks(dirty_string, contamination_indices)
205
+
206
+ ##############
207
+ # Slow python
208
+ ##############
209
+
210
+ def normalize_string(self, s: str) -> str:
211
+ return s.translate(self.translation_table)
212
+
213
+ def register_contaminant_python(self, dirt_string: str) -> None:
214
+ self.dirt_ngrams.update(
215
+ word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
216
+ )
217
+
218
+ def clean_python(self, dirty_string: str) -> List[str]:
219
+ contamination_indices = (
220
+ (None, *idx_pair)
221
+ for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
222
+ if self.normalize_string(dirty_ngram) in self.dirt_ngrams
223
+ )
224
+ return self._split_chunks(dirty_string, contamination_indices)
225
+
226
+
227
+ ##################################################################
228
+ # Tests
229
+ #################################################################
230
+
231
+ # def print_cpp():
232
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
233
+
234
+ # for i in range(1, 10, 2):
235
+ # pprint(janitor_util.clean_ngram(source, string.punctuation, i))
236
+ # for ngram, start, end in \
237
+ # janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
238
+ # print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
239
+
240
+
241
+ # def test_cpp():
242
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
243
+ # contaminant = "dirty boy. Clean he he"
244
+
245
+ # jan_python = Janitor()
246
+ # jan_cpp = Janitor()
247
+
248
+ # jan_python.register_contaminant_python(contaminant)
249
+ # jan_cpp.register_contaminant(contaminant)
250
+
251
+ # assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
252
+
253
+ # assert jan_python.clean_python(source) == jan_cpp.clean(source), \
254
+ # (jan_python.clean_python(source), jan_cpp.clean(source))
255
+
256
+ # print("Passed test, python==cpp")
257
+
258
+
259
+ # def benchmark():
260
+ # # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
261
+ # setup = \
262
+ # """
263
+ # with open("data/enwik8", "r") as f:
264
+ # data = f.read()
265
+ # jan = Janitor(too_dirty_cutoff=1000)
266
+ # jan.register_contaminant('''
267
+ # theories is that there is a connection between &quot;geekdom&quot; and autism.
268
+ # This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
269
+ # The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
270
+ # movement{{ref|Wired}}. This article, many professionals assert, is just one example of
271
+ # the media's application of mental disease labels to what is actually variant normal behavior
272
+ # &amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
273
+ # interests, even when they seem unusual to others, are not in themselves signs of autism or
274
+ # Asperger's syndrome. Others assert that it is actually the medical profession which is applying
275
+ # mental disease labels to children who in the past would have simply been accepted as a little
276
+ # different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
277
+ # Due to the recent publicity surrounding autism and autis
278
+ # ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
279
+ # oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
280
+ # paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
281
+ # would last, took a cautious approach, preferring to save the revenue rather than investing it in
282
+ # development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
283
+ # to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
284
+ # brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
285
+ # with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
286
+ # ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
287
+ # ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
288
+ # Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
289
+ # [[United Arab Emirates]]. After the Emirates gained independence in 1971,
290
+ # ''')
291
+ # """
292
+
293
+ # n = 1
294
+ # print(f"Timing {n} run on 100 MB")
295
+ # print("Register contaminant")
296
+ # # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
297
+ # print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
298
+
299
+ # print("Clean")
300
+ # # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
301
+ # print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
302
+
303
+
304
+ # def test_janitor_general():
305
+ # source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
306
+ # contaminant = "dirty boy. Clean he he"
307
+
308
+ # jan = Janitor(ngram_n=3)
309
+ # jan.register_contaminant(contaminant)
310
+ # cleaned = " ".join(jan.clean(source))
311
+ # for contam in jan.dirt_ngrams:
312
+ # assert contam not in cleaned, contam
313
+
314
+ # filename = "data/saved_contam"
315
+ # jan.save_contamination_ngrams(filename)
316
+
317
+ # jan = Janitor(ngram_n=3)
318
+ # jan.load_contamination_ngrams(filename)
319
+ # cleaned = " ".join(jan.clean(source))
320
+ # for contam in jan.dirt_ngrams:
321
+ # assert contam not in cleaned, contam
322
+
323
+
324
+ # if __name__ == "__main__":
325
+ # test()
326
+ # # print_cpp()
327
+ # # test_cpp()
328
+ # # benchmark()
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/evaluator.py ADDED
@@ -0,0 +1,736 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import json
3
+ import logging
4
+ import random
5
+ import time
6
+ from collections import defaultdict
7
+ from typing import TYPE_CHECKING, List, Optional, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ import lm_eval.api.metrics
13
+ import lm_eval.api.registry
14
+ import lm_eval.api.task
15
+ import lm_eval.models
16
+ from lm_eval.caching.cache import delete_cache
17
+ from lm_eval.evaluator_utils import (
18
+ consolidate_group_results,
19
+ consolidate_results,
20
+ get_sample_size,
21
+ get_subtask_list,
22
+ get_task_list,
23
+ prepare_print_tasks,
24
+ print_writeout,
25
+ run_task_tests,
26
+ )
27
+ from lm_eval.loggers import EvaluationTracker
28
+ from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
29
+ from lm_eval.tasks import (
30
+ TaskManager,
31
+ get_task_dict,
32
+ )
33
+ from lm_eval.utils import (
34
+ handle_non_serializable,
35
+ hash_string,
36
+ positional_deprecated,
37
+ setup_logging,
38
+ simple_parse_args_string,
39
+ )
40
+
41
+
42
+ if TYPE_CHECKING:
43
+ from lm_eval.api.model import LM
44
+ from lm_eval.api.task import Task
45
+
46
+ eval_logger = logging.getLogger(__name__)
47
+
48
+
49
+ @positional_deprecated
50
+ def simple_evaluate(
51
+ model,
52
+ model_args: Optional[Union[str, dict]] = None,
53
+ tasks: Optional[List[Union[str, dict, object]]] = None,
54
+ num_fewshot: Optional[int] = None,
55
+ batch_size: Optional[Union[int, str]] = None,
56
+ max_batch_size: Optional[int] = None,
57
+ device: Optional[str] = None,
58
+ use_cache: Optional[str] = None,
59
+ cache_requests: bool = False,
60
+ rewrite_requests_cache: bool = False,
61
+ delete_requests_cache: bool = False,
62
+ limit: Optional[Union[int, float]] = None,
63
+ bootstrap_iters: int = 100000,
64
+ check_integrity: bool = False,
65
+ write_out: bool = False,
66
+ log_samples: bool = True,
67
+ evaluation_tracker: Optional[EvaluationTracker] = None,
68
+ system_instruction: Optional[str] = None,
69
+ apply_chat_template: Union[bool, str] = False,
70
+ fewshot_as_multiturn: bool = False,
71
+ gen_kwargs: Union[str, dict, None] = None,
72
+ task_manager: Optional[TaskManager] = None,
73
+ verbosity=None,
74
+ predict_only: bool = False,
75
+ random_seed: int = 0,
76
+ numpy_random_seed: int = 1234,
77
+ torch_random_seed: int = 1234,
78
+ fewshot_random_seed: int = 1234,
79
+ confirm_run_unsafe_code: bool = False,
80
+ metadata: Optional[dict] = None,
81
+ ):
82
+ """Instantiate and evaluate a model on a list of tasks.
83
+
84
+ :param model: Union[str, LM]
85
+ Name of model or LM object, see lm_eval.models.get_model
86
+ :param model_args: Optional[str, dict]
87
+ String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
88
+ Ignored if `model` argument is a LM object.
89
+ :param tasks: list[Union[str, dict, Task]]
90
+ List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
91
+ :param num_fewshot: int
92
+ Number of examples in few-shot context
93
+ :param batch_size: int or str, optional
94
+ Batch size for model
95
+ :param max_batch_size: int, optional
96
+ Maximal batch size to try with automatic batch size detection
97
+ :param device: str, optional
98
+ PyTorch device (e.g. "cpu" or "cuda:0") for running models
99
+ :param use_cache: str, optional
100
+ A path to a sqlite db file for caching model responses. `None` if not caching.
101
+ :param cache_requests: bool, optional
102
+ Speed up evaluation by caching the building of dataset requests. `None` if not caching.
103
+ :param rewrite_requests_cache: bool, optional
104
+ Rewrites all the request cache if set to `True`. `None` if not desired.
105
+ :param delete_requests_cache: bool, optional
106
+ Deletes all the request cache if set to `True`. `None` if not desired.
107
+ :param limit: int or float, optional
108
+ Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
109
+ :param bootstrap_iters:
110
+ Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
111
+ :param check_integrity: bool
112
+ Whether to run the relevant part of the test suite for the tasks
113
+ :param write_out: bool
114
+ If True, write out an example document and model input for checking task integrity
115
+ :param log_samples: bool
116
+ If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
117
+ :param system_instruction: str
118
+ System instruction to be applied to the prompt
119
+ :param apply_chat_template: Union[bool, str]
120
+ Specifies whether to apply a chat template to the prompt.
121
+ - If set to True, the default chat template is applied.
122
+ - If set to a string, applies the specified chat template by name.
123
+ Defaults to False (no chat template applied).
124
+ :param fewshot_as_multiturn: bool
125
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
126
+ :param gen_kwargs: dict or comma-separated string
127
+ Arguments for model generation
128
+ Ignored for all tasks with loglikelihood output_type
129
+ :param verbosity: str
130
+ Verbosity level for logging
131
+ :param predict_only: bool
132
+ If true only model outputs will be generated and returned. Metrics will not be evaluated
133
+ :param random_seed: int
134
+ Random seed for python's random module. If set to None, the seed will not be set.
135
+ :param numpy_random_seed: int
136
+ Random seed for numpy. If set to None, the seed will not be set.
137
+ :param torch_random_seed: int
138
+ Random seed for torch. If set to None, the seed will not be set.
139
+ :param fewshot_random_seed: int
140
+ Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
141
+ :param metadata: dict
142
+ Additional metadata to be added to the task manager. Will get passed to the download function of the task.
143
+
144
+ return
145
+ Dictionary of results
146
+ """
147
+ if verbosity is not None:
148
+ setup_logging(verbosity=verbosity)
149
+ start_date = time.time()
150
+
151
+ if isinstance(model_args, str) and (
152
+ "instruct" in model_args and not apply_chat_template
153
+ ):
154
+ eval_logger.warning(
155
+ "Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
156
+ )
157
+
158
+ if delete_requests_cache:
159
+ eval_logger.info("Deleting requests cache...")
160
+ delete_cache()
161
+
162
+ seed_message = []
163
+ if random_seed is not None:
164
+ # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
165
+ seed_message.append(f"Setting random seed to {random_seed}")
166
+ random.seed(random_seed)
167
+
168
+ if numpy_random_seed is not None:
169
+ seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
170
+ np.random.seed(numpy_random_seed)
171
+
172
+ if torch_random_seed is not None:
173
+ seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
174
+ torch.manual_seed(torch_random_seed)
175
+
176
+ if fewshot_random_seed is not None:
177
+ seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
178
+
179
+ if seed_message:
180
+ eval_logger.info(" | ".join(seed_message))
181
+
182
+ if tasks is None:
183
+ tasks = []
184
+ if len(tasks) == 0:
185
+ raise ValueError(
186
+ "No tasks specified, or no tasks found. Please verify the task names."
187
+ )
188
+
189
+ if gen_kwargs is not None:
190
+ if isinstance(gen_kwargs, str):
191
+ gen_kwargs = simple_parse_args_string(gen_kwargs)
192
+ eval_logger.warning(
193
+ f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
194
+ "Ensure 'do_sample=True' for non-greedy decoding!"
195
+ )
196
+ if not gen_kwargs:
197
+ gen_kwargs = None
198
+
199
+ if isinstance(model, str):
200
+ if model_args is None:
201
+ eval_logger.warning("model_args not specified. Using defaults.")
202
+ model_args = ""
203
+
204
+ if isinstance(model_args, dict):
205
+ eval_logger.info(
206
+ f"Initializing {model} model, with arguments: {model_args}"
207
+ )
208
+ lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
209
+ model_args,
210
+ {
211
+ "batch_size": batch_size,
212
+ "max_batch_size": max_batch_size,
213
+ "device": device,
214
+ },
215
+ )
216
+
217
+ else:
218
+ eval_logger.info(
219
+ f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
220
+ )
221
+ lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
222
+ model_args,
223
+ {
224
+ "batch_size": batch_size,
225
+ "max_batch_size": max_batch_size,
226
+ "device": device,
227
+ },
228
+ )
229
+ else:
230
+ if not isinstance(model, lm_eval.api.model.LM):
231
+ raise TypeError(
232
+ f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
233
+ )
234
+ eval_logger.info("Using pre-initialized model")
235
+ lm = model
236
+
237
+ if use_cache is not None:
238
+ eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
239
+ lm = lm_eval.api.model.CachingLM(
240
+ lm,
241
+ use_cache
242
+ # each rank receives a different cache db.
243
+ # necessary to avoid multiple writes to cache at once
244
+ + "_rank"
245
+ + str(lm.rank)
246
+ + ".db",
247
+ )
248
+
249
+ if task_manager is None:
250
+ metadata = (
251
+ simple_parse_args_string(model_args)
252
+ if isinstance(model_args, str)
253
+ else model_args
254
+ if isinstance(model_args, dict)
255
+ else {}
256
+ ) | (metadata or {})
257
+ task_manager = TaskManager(metadata=metadata)
258
+
259
+ task_dict = get_task_dict(
260
+ tasks,
261
+ task_manager,
262
+ )
263
+
264
+ # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
265
+ # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
266
+ def _adjust_config(task_dict):
267
+ adjusted_task_dict = {}
268
+ for task_name, task_obj in task_dict.items():
269
+ if isinstance(task_obj, dict):
270
+ adjusted_task_dict = {
271
+ **adjusted_task_dict,
272
+ **{task_name: _adjust_config(task_obj)},
273
+ }
274
+
275
+ else:
276
+ if task_obj.get_config("output_type") == "generate_until":
277
+ if gen_kwargs is not None:
278
+ task_obj.set_config(
279
+ key="generation_kwargs", value=gen_kwargs, update=True
280
+ )
281
+ eval_logger.info(
282
+ f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
283
+ )
284
+
285
+ if predict_only:
286
+ eval_logger.info(
287
+ f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
288
+ )
289
+ # we have to change the class properties post-hoc. This is pretty hacky.
290
+ task_obj.override_metric(metric_name="bypass")
291
+
292
+ # override tasks' fewshot values to the provided num_fewshot arg value
293
+ # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
294
+ if num_fewshot is not None:
295
+ if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
296
+ eval_logger.info(
297
+ f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
298
+ )
299
+ else:
300
+ eval_logger.warning(
301
+ f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
302
+ )
303
+ task_obj.set_config(key="num_fewshot", value=num_fewshot)
304
+ else:
305
+ # if num_fewshot not provided, and the task does not define a default one, default to 0
306
+ if (
307
+ default_num_fewshot := task_obj.get_config("num_fewshot")
308
+ ) is None:
309
+ task_obj.set_config(key="num_fewshot", value=0)
310
+ # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
311
+ task_obj.set_fewshot_seed(seed=fewshot_random_seed)
312
+
313
+ adjusted_task_dict[task_name] = task_obj
314
+
315
+ return adjusted_task_dict
316
+
317
+ task_dict = _adjust_config(task_dict)
318
+
319
+ if check_integrity:
320
+ run_task_tests(task_list=tasks)
321
+
322
+ if evaluation_tracker is not None:
323
+ evaluation_tracker.general_config_tracker.log_experiment_args(
324
+ model_source=model,
325
+ model_args=model_args,
326
+ system_instruction=system_instruction,
327
+ chat_template=lm.chat_template(apply_chat_template)
328
+ if apply_chat_template
329
+ else None,
330
+ fewshot_as_multiturn=fewshot_as_multiturn,
331
+ )
332
+
333
+ results = evaluate(
334
+ lm=lm,
335
+ task_dict=task_dict,
336
+ limit=limit,
337
+ cache_requests=cache_requests,
338
+ rewrite_requests_cache=rewrite_requests_cache,
339
+ bootstrap_iters=bootstrap_iters,
340
+ write_out=write_out,
341
+ log_samples=True if predict_only else log_samples,
342
+ system_instruction=system_instruction,
343
+ apply_chat_template=apply_chat_template,
344
+ fewshot_as_multiturn=fewshot_as_multiturn,
345
+ verbosity=verbosity,
346
+ confirm_run_unsafe_code=confirm_run_unsafe_code,
347
+ )
348
+ if verbosity is not None:
349
+ setup_logging(verbosity=verbosity)
350
+
351
+ if lm.rank == 0:
352
+ if isinstance(model, str):
353
+ model_name = model
354
+ elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
355
+ model_name = model.config._name_or_path
356
+ else:
357
+ model_name = type(model).__name__
358
+
359
+ # add info about the model and few shot config
360
+ results["config"] = {
361
+ "model": model_name,
362
+ "model_args": model_args,
363
+ }
364
+ # add more detailed model info if available
365
+ if isinstance(lm, lm_eval.models.huggingface.HFLM):
366
+ results["config"].update(lm.get_model_info())
367
+ # add info about execution
368
+ results["config"].update(
369
+ {
370
+ "batch_size": batch_size,
371
+ "batch_sizes": (
372
+ list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
373
+ ),
374
+ "device": device,
375
+ "use_cache": use_cache,
376
+ "limit": limit,
377
+ "bootstrap_iters": bootstrap_iters,
378
+ "gen_kwargs": gen_kwargs,
379
+ "random_seed": random_seed,
380
+ "numpy_seed": numpy_random_seed,
381
+ "torch_seed": torch_random_seed,
382
+ "fewshot_seed": fewshot_random_seed,
383
+ }
384
+ )
385
+ results["git_hash"] = get_git_commit_hash()
386
+ results["date"] = start_date
387
+ add_env_info(results) # additional environment info to results
388
+ add_tokenizer_info(results, lm) # additional info about tokenizer
389
+ return results
390
+ else:
391
+ return None
392
+
393
+
394
+ @positional_deprecated
395
+ def evaluate(
396
+ lm: "LM",
397
+ task_dict,
398
+ limit: Optional[int] = None,
399
+ cache_requests: bool = False,
400
+ rewrite_requests_cache: bool = False,
401
+ bootstrap_iters: Optional[int] = 100000,
402
+ write_out: bool = False,
403
+ log_samples: bool = True,
404
+ system_instruction: Optional[str] = None,
405
+ apply_chat_template: Union[bool, str] = False,
406
+ fewshot_as_multiturn: bool = False,
407
+ verbosity: str = "INFO",
408
+ confirm_run_unsafe_code: bool = False,
409
+ ):
410
+ """Instantiate and evaluate a model on a list of tasks.
411
+
412
+ :param lm: obj
413
+ Language Model
414
+ :param task_dict: dict[str, Task]
415
+ Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
416
+ :param limit: int, optional
417
+ Limit the number of examples per task (only use this for testing)
418
+ :param cache_requests: bool, optional
419
+ Speed up evaluation by caching the building of dataset requests.
420
+ :param rewrite_requests_cache: bool, optional
421
+ Rewrites all the request cache if set to `True`.
422
+ :param bootstrap_iters:
423
+ Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
424
+ :param write_out: bool
425
+ If True, write out an example document and model input for checking task integrity
426
+ :param log_samples: bool
427
+ If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
428
+ :param system_instruction: str
429
+ System instruction to be applied to the prompt
430
+ :param apply_chat_template: Union[bool, str]
431
+ Specifies whether to apply a chat template to the prompt.
432
+ - If set to True, the default chat template is applied.
433
+ - If set to a string, applies the specified chat template by name.
434
+ Defaults to False (no chat template applied).
435
+ :param fewshot_as_multiturn: bool
436
+ Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
437
+ :param verbosity: str
438
+ Verbosity level for logging
439
+ :param confirm_run_unsafe_code: bool
440
+ Whether to confirm running tasks marked as unsafe.
441
+ :return
442
+ Dictionary of results
443
+ """
444
+
445
+ if apply_chat_template:
446
+ eval_logger.warning(
447
+ "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
448
+ )
449
+
450
+ # tracks all Instances/requests a model must generate output on.
451
+ requests = defaultdict(list)
452
+ # stores the amount to pad out reqs per req. type so that
453
+ # number of fwd passes per distributed rank is equal
454
+ padding_requests = defaultdict(int)
455
+
456
+ # get lists of group hierarchy and each type of request
457
+ eval_tasks = get_task_list(task_dict)
458
+ if not log_samples:
459
+ if not all(
460
+ "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
461
+ for task_output in eval_tasks
462
+ ):
463
+ raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
464
+
465
+ # validation checks:
466
+ # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
467
+ # 2.are we running code that is marked as unsafe.
468
+ incompatible_tasks = []
469
+ for task_output in eval_tasks:
470
+ task: Task = task_output.task
471
+
472
+ if getattr(lm, "MULTIMODAL", False) != getattr(task, "MULTIMODAL", False):
473
+ incompatible_tasks.append(task_output.task_name)
474
+ elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
475
+ raise ValueError(
476
+ f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
477
+ )
478
+ if len(incompatible_tasks) > 0:
479
+ if not getattr(lm, "MULTIMODAL", False):
480
+ raise ValueError(
481
+ f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
482
+ )
483
+ else:
484
+ raise ValueError(
485
+ f"Attempted to run tasks: {incompatible_tasks} which are text-only, but used a model type which only currently supports multimodal tasks."
486
+ )
487
+ # end validation check
488
+
489
+ # Cache the limit arg.
490
+ limit_arg = limit
491
+ limits = []
492
+ for task_output in eval_tasks:
493
+ task: Task = task_output.task
494
+
495
+ limit = get_sample_size(task, limit_arg)
496
+ limits.append(limit)
497
+ task.build_all_requests(
498
+ limit=limit,
499
+ rank=lm.rank,
500
+ world_size=lm.world_size,
501
+ cache_requests=cache_requests,
502
+ rewrite_requests_cache=rewrite_requests_cache,
503
+ system_instruction=system_instruction,
504
+ apply_chat_template=bool(apply_chat_template),
505
+ fewshot_as_multiturn=fewshot_as_multiturn,
506
+ chat_template=getattr(lm, "apply_chat_template")
507
+ if apply_chat_template
508
+ else None,
509
+ tokenizer_name=getattr(lm, "tokenizer_name", "")
510
+ if apply_chat_template
511
+ else "",
512
+ )
513
+ eval_logger.debug(
514
+ f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
515
+ )
516
+ if write_out:
517
+ print_writeout(task)
518
+ # aggregate Instances by LM method requested to get output.
519
+ for instance in task.instances:
520
+ reqtype = instance.request_type
521
+ requests[reqtype].append(instance)
522
+
523
+ if lm.world_size > 1:
524
+ instances_rnk = torch.tensor(len(task._instances), device=lm.device)
525
+ gathered_item = (
526
+ lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
527
+ )
528
+ # "multiple_choice" task types dispatch (several) "loglikelihood" request types
529
+ reqtype = (
530
+ "loglikelihood"
531
+ if task.OUTPUT_TYPE == "multiple_choice"
532
+ else task.OUTPUT_TYPE
533
+ )
534
+ # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
535
+ numpad = max(gathered_item) - gathered_item[lm.rank]
536
+ # todo: may not account for padding in cases like SquadV2 which has multiple req types
537
+ padding_requests[reqtype] += numpad
538
+
539
+ ### Run LM on inputs, get all outputs ###
540
+ # execute each type of request
541
+ for reqtype, reqs in requests.items():
542
+ eval_logger.info(f"Running {reqtype} requests")
543
+ # create `K` copies of each request `req` based off `K = req.repeats`
544
+ cloned_reqs = []
545
+ for req in reqs:
546
+ cloned_reqs.extend([req] * req.repeats)
547
+
548
+ if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
549
+ for _ in range(padding_requests[reqtype]):
550
+ cloned_reqs.extend([req] * req.repeats)
551
+
552
+ # run requests through model
553
+ resps = getattr(lm, reqtype)(cloned_reqs)
554
+
555
+ # put responses from model into a list of length K for each request.
556
+ for x, req in zip(resps, cloned_reqs):
557
+ req.resps.append(x)
558
+
559
+ if lm.world_size > 1:
560
+ lm.accelerator.wait_for_everyone()
561
+
562
+ RANK = lm.rank
563
+ WORLD_SIZE = lm.world_size
564
+ ### Postprocess outputs ###
565
+ # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
566
+ for task_output, limit in zip(eval_tasks, limits):
567
+ task = task_output.task
568
+ task.apply_filters()
569
+
570
+ ### Collect values of metrics on all datapoints ###
571
+ # # unpack results and sort back in order and return control to Task
572
+ # TODO: make it possible to use a different metric per filter
573
+ # Pre-process task.instances to group by doc_id
574
+ instances_by_doc_id = defaultdict(list)
575
+ for instance in task.instances:
576
+ instances_by_doc_id[instance.doc_id].append(instance)
577
+ # Sort instances within each group
578
+ for instances in instances_by_doc_id.values():
579
+ instances.sort(key=lambda x: x.idx)
580
+ # iterate over different filters used
581
+ for filter_key in task.instances[0].filtered_resps.keys():
582
+ doc_iterator = task.doc_iterator(
583
+ rank=RANK, limit=limit, world_size=WORLD_SIZE
584
+ )
585
+ for doc_id, doc in doc_iterator:
586
+ requests = instances_by_doc_id[doc_id]
587
+ metrics = task.process_results(
588
+ doc, [req.filtered_resps[filter_key] for req in requests]
589
+ )
590
+ if log_samples:
591
+ target = task.doc_to_target(doc)
592
+ example = {
593
+ "doc_id": doc_id,
594
+ "doc": doc,
595
+ "target": target,
596
+ "arguments": [req.args for req in requests],
597
+ "resps": [req.resps for req in requests],
598
+ "filtered_resps": [
599
+ req.filtered_resps[filter_key] for req in requests
600
+ ],
601
+ "filter": filter_key,
602
+ "metrics": list(metrics.keys()),
603
+ "doc_hash": hash_string(
604
+ json.dumps(
605
+ requests[0].doc,
606
+ indent=2,
607
+ default=handle_non_serializable,
608
+ ensure_ascii=False,
609
+ )
610
+ ),
611
+ "prompt_hash": hash_string(requests[0].arguments[0]),
612
+ "target_hash": hash_string(str(target)),
613
+ }
614
+ example.update(metrics)
615
+ task_output.logged_samples.append(example)
616
+ for metric, value in metrics.items():
617
+ task_output.sample_metrics[(metric, filter_key)].append(value)
618
+
619
+ if WORLD_SIZE > 1:
620
+ # if multigpu, then gather data across all ranks to rank 0
621
+ # first gather logged samples across all ranks
622
+ for task_output in eval_tasks:
623
+ if log_samples:
624
+ # for task_name, task_samples in list(samples.items()):
625
+ full_samples = [None] * WORLD_SIZE if RANK == 0 else None
626
+ torch.distributed.gather_object(
627
+ obj=task_output.logged_samples,
628
+ object_gather_list=full_samples,
629
+ dst=0,
630
+ )
631
+
632
+ if RANK == 0:
633
+ task_output.logged_samples = list(
634
+ itertools.chain.from_iterable(full_samples)
635
+ )
636
+
637
+ # then collect metrics across all ranks
638
+ for metrics in task_output.sample_metrics:
639
+ metric_list = [None] * WORLD_SIZE if RANK == 0 else None
640
+ torch.distributed.gather_object(
641
+ obj=task_output.sample_metrics[metrics],
642
+ object_gather_list=metric_list,
643
+ dst=0,
644
+ )
645
+ if RANK == 0:
646
+ task_output.sample_metrics[metrics] = list(
647
+ itertools.chain.from_iterable(metric_list)
648
+ )
649
+
650
+ if RANK == 0:
651
+ ### Aggregate results over all datapoints ###
652
+ # aggregate results ; run bootstrap CIs
653
+ for task_output in eval_tasks:
654
+ task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
655
+ (
656
+ results,
657
+ samples,
658
+ configs,
659
+ versions,
660
+ num_fewshot,
661
+ higher_is_better,
662
+ ) = consolidate_results(eval_tasks)
663
+
664
+ ### Calculate group metrics ###
665
+ if bool(results):
666
+ results, versions, show_group_table, *_ = consolidate_group_results(
667
+ results, versions, task_dict
668
+ )
669
+
670
+ results_agg, group_agg = prepare_print_tasks(task_dict, results)
671
+ subtask_list = get_subtask_list(task_dict)
672
+
673
+ # collect all higher_is_better values for metrics
674
+ # in the group's subtasks.
675
+ # TODO: clean this up ; unify with the below metric_list loop?
676
+ _higher_is_better = {}
677
+ for group, task_list in subtask_list.items():
678
+ if (
679
+ len(task_list) != 0
680
+ ): # subtask list will list "task_name": [] for solo tasks
681
+ for task in task_list:
682
+ for m, h in higher_is_better[task].items():
683
+ if m not in _higher_is_better.keys():
684
+ _higher_is_better[m] = h
685
+
686
+ if (
687
+ m in _higher_is_better
688
+ and _higher_is_better[m] is not None
689
+ and _higher_is_better[m] != h
690
+ ):
691
+ eval_logger.warning(
692
+ f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
693
+ )
694
+ _higher_is_better[m] = None
695
+ higher_is_better[group] = _higher_is_better
696
+
697
+ results_dict = {
698
+ "results": dict(results_agg.items()),
699
+ **(
700
+ {"groups": dict(group_agg.items())}
701
+ if (bool(group_agg) & show_group_table)
702
+ else {}
703
+ ),
704
+ "group_subtasks": dict(reversed(subtask_list.items())),
705
+ "configs": dict(sorted(configs.items())),
706
+ "versions": dict(sorted(versions.items())),
707
+ "n-shot": dict(sorted(num_fewshot.items())),
708
+ "higher_is_better": dict(sorted(higher_is_better.items())),
709
+ "n-samples": {
710
+ task_output.task_name: {
711
+ "original": len(task_output.task.eval_docs),
712
+ "effective": min(
713
+ limit if limit else len(task_output.task.eval_docs),
714
+ len(task_output.task.eval_docs),
715
+ ),
716
+ }
717
+ for task_output, limit in zip(eval_tasks, limits)
718
+ },
719
+ }
720
+ if log_samples:
721
+ results_dict["samples"] = dict(samples)
722
+
723
+ return results_dict
724
+
725
+ else:
726
+ return None
727
+
728
+
729
+ def request_caching_arg_to_dict(cache_requests: str) -> dict:
730
+ request_caching_args = {
731
+ "cache_requests": cache_requests in {"true", "refresh"},
732
+ "rewrite_requests_cache": cache_requests == "refresh",
733
+ "delete_requests_cache": cache_requests == "delete",
734
+ }
735
+
736
+ return request_caching_args
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/evaluator_utils.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import math
4
+ import pathlib
5
+ import sys
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ from lm_eval.api.group import ConfigurableGroup
9
+ from lm_eval.api.metrics import (
10
+ aggregate_subtask_metrics,
11
+ mean,
12
+ pooled_sample_stderr,
13
+ stderr_for_metric,
14
+ )
15
+ from lm_eval.api.task import Task
16
+ from lm_eval.utils import positional_deprecated
17
+
18
+
19
+ eval_logger = logging.getLogger(__name__)
20
+
21
+
22
+ class TaskOutput:
23
+ """
24
+ Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task.
25
+
26
+ Attributes:
27
+ task (object): The task object.
28
+ task_name (str): The name of the task.
29
+ task_config (dict): The configuration of the task.
30
+ version (str): The version of the task.
31
+ group_name (str): The name of the task group.
32
+ n_shot (int): The number of shots for the task.
33
+ task_alias (str): The alias of the task.
34
+ group_alias (str): The alias of the task group.
35
+ is_group (bool): Indicates if the task is a group.
36
+ logged_samples (list): The list of logged samples.
37
+ sample_len (int): The length of the samples.
38
+ sample_metrics (defaultdict): The dictionary of samples' metrics.
39
+ agg_metrics (defaultdict): The dictionary of aggregate metrics.
40
+
41
+ Methods:
42
+ from_taskdict(cls, task_name: str, task):
43
+ Creates a TaskOutput instance from a task dictionary.
44
+
45
+ calculate_aggregate_metric(bootstrap_iters=100000) -> None:
46
+ Calculates the aggregate metrics for the task.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ task=None,
52
+ task_name=None,
53
+ task_config=None,
54
+ version=None,
55
+ group_name=None,
56
+ n_shot=None,
57
+ task_alias=None,
58
+ group_alias=None,
59
+ is_group=None,
60
+ ):
61
+ self.task = task
62
+ self.task_config = task_config
63
+ self.task_name = task_name
64
+ self.group_name = group_name
65
+ self.version = version
66
+ self.n_shot = n_shot
67
+ self.task_alias = task_alias
68
+ self.group_alias = group_alias
69
+ self.is_group = is_group
70
+ self.logged_samples = []
71
+ self.sample_len = None
72
+ self.sample_metrics = collections.defaultdict(list)
73
+ self.agg_metrics = collections.defaultdict(list)
74
+
75
+ @classmethod
76
+ def from_taskdict(cls, task_name: str, task):
77
+ if isinstance(task, tuple):
78
+ group_name, task = task
79
+ else:
80
+ group_name = None
81
+ if not task:
82
+ # these gets filtered out in get_task_list
83
+ # once they are added to group hierarchy
84
+ is_group = True
85
+ return cls(
86
+ task=task, task_name=task_name, is_group=is_group, group_name=group_name
87
+ )
88
+ version = task.VERSION
89
+ task_config = dict(task.dump_config())
90
+ if (n_shot := task_config.get("num_fewshot")) == 0:
91
+ n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
92
+ task_alias = task_config.get("alias")
93
+ group_alias = task_config.get("group_alias")
94
+ return cls(
95
+ task=task,
96
+ task_name=task_name,
97
+ task_config=task_config,
98
+ group_name=group_name,
99
+ version=version,
100
+ n_shot=n_shot,
101
+ task_alias=task_alias,
102
+ group_alias=group_alias,
103
+ )
104
+
105
+ def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
106
+ for (metric, filter_key), items in self.sample_metrics.items():
107
+ try:
108
+ agg_fn = self.task.aggregation()[metric]
109
+ except KeyError:
110
+ # This is when process results output an arbitrary metric
111
+ # TODO: Handle this better and allow other aggregate functions other than mean.
112
+ agg_fn = mean
113
+ metric_key = f"{metric},{filter_key}"
114
+ self.agg_metrics[metric_key] = agg_fn(items)
115
+ self.sample_len = len(items) # TODO: same sample size for each metric?
116
+ if isinstance(bootstrap_iters, int):
117
+ stderr_fn = stderr_for_metric(
118
+ metric=agg_fn,
119
+ bootstrap_iters=min(bootstrap_iters, 100)
120
+ if metric in ["bleu", "chrf", "ter"]
121
+ else bootstrap_iters,
122
+ )
123
+ self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
124
+ stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
125
+ )
126
+ else:
127
+ raise ValueError(
128
+ f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
129
+ )
130
+
131
+ def __repr__(self):
132
+ return (
133
+ f"TaskOutput(task_name={self.task_name}, "
134
+ f"group_name={self.group_name}, "
135
+ f"version={self.version}, "
136
+ f"n_shot={self.n_shot}, "
137
+ f"task_alias={self.task_alias}, "
138
+ f"group_alias={self.group_alias})"
139
+ )
140
+
141
+
142
+ def get_task_list(task_dict: dict) -> List[TaskOutput]:
143
+ outputs = []
144
+ for task_name, task_obj in task_dict.items():
145
+ if isinstance(task_obj, dict):
146
+ _outputs = get_task_list(task_obj)
147
+ outputs.extend(_outputs)
148
+ else:
149
+ task_output = TaskOutput.from_taskdict(task_name, task_obj)
150
+ outputs.append(task_output)
151
+
152
+ return outputs
153
+
154
+
155
+ def get_subtask_list(task_dict, task_root=None, depth=0):
156
+ subtask_list = {}
157
+ for group_obj, task_obj in task_dict.items():
158
+ if isinstance(group_obj, ConfigurableGroup):
159
+ # group_name = group_obj.group_name
160
+ group_name = group_obj.group_name
161
+ else:
162
+ group_name = group_obj
163
+ if isinstance(task_obj, dict):
164
+ _subtask_list = get_subtask_list(
165
+ task_obj, task_root=group_name, depth=depth + 1
166
+ )
167
+ if task_root:
168
+ subtask_list.setdefault((task_root, depth), []).extend(
169
+ [
170
+ _task
171
+ for (_task, _depth) in _subtask_list.keys()
172
+ if (_depth - 1) == depth
173
+ ]
174
+ )
175
+
176
+ subtask_list = {**subtask_list, **_subtask_list}
177
+ else:
178
+ if isinstance(task_obj, ConfigurableGroup):
179
+ # group_or_task_name = task_obj.group_name
180
+ group_or_task_name = task_obj.group_name
181
+ elif isinstance(task_obj, Task):
182
+ # group_or_task_name = task_obj.task_name
183
+ group_or_task_name = task_obj.task_name
184
+
185
+ if task_root is None:
186
+ subtask_list.setdefault((group_or_task_name, depth), [])
187
+ else:
188
+ subtask_list.setdefault((task_root, depth), []).append(
189
+ group_or_task_name
190
+ )
191
+
192
+ if depth == 0:
193
+ _subtask_list = {}
194
+ for group_key, task_list in subtask_list.items():
195
+ group_name, depth = group_key
196
+ _subtask_list[group_name] = task_list
197
+ subtask_list = _subtask_list
198
+
199
+ return subtask_list
200
+
201
+
202
+ def print_writeout(task) -> None:
203
+ for inst in task.instances:
204
+ # print the prompt for the first few documents
205
+ if inst.doc_id < 1:
206
+ eval_logger.info(
207
+ f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
208
+ \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
209
+ )
210
+ eval_logger.info(f"Request: {str(inst)}")
211
+
212
+
213
+ def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
214
+ if limit is not None:
215
+ limit = (
216
+ int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
217
+ )
218
+ return limit
219
+
220
+
221
+ def prepare_print_tasks(
222
+ task_dict: dict,
223
+ results: dict,
224
+ task_depth=0,
225
+ group_depth=0,
226
+ ) -> Tuple[dict, dict]:
227
+ """
228
+ @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
229
+ value is a list of task names.
230
+ @param results: Dictionary containing the results of each task. Each key is a
231
+ group name and its value is a dictionary of task results.
232
+ @param task_depth: The indentation level for printing the task
233
+ hierarchy. Default is 0.
234
+ @param group_depth: The indentation level for printing the group
235
+ hierarchy. Default is 0.
236
+ @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
237
+ aggregated results for each task, and groups_agg contains aggregated results for each group.
238
+
239
+ Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
240
+ """
241
+
242
+ def _sort_task_dict(task_dict):
243
+ """
244
+ Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
245
+ Required so that we end up sorting within each sub-header correctly.
246
+ """
247
+
248
+ return dict(
249
+ sorted(
250
+ task_dict.items(),
251
+ key=lambda item: item[0].group_name
252
+ if isinstance(item[0], ConfigurableGroup)
253
+ else item[0],
254
+ )
255
+ )
256
+
257
+ task_agg = collections.defaultdict(dict)
258
+ group_agg = collections.defaultdict(dict)
259
+ task_dict = _sort_task_dict(task_dict)
260
+ for task_or_group_name, task_or_group_obj in task_dict.items():
261
+ tab_string = " " * task_depth + "- " if task_depth > 0 else ""
262
+ if isinstance(task_or_group_name, ConfigurableGroup):
263
+ # string_name = task_or_group_name.group_name
264
+ name = task_or_group_name.group_name
265
+ from_configurable_group = True
266
+ task_or_group_obj = _sort_task_dict(task_or_group_obj)
267
+ elif isinstance(task_or_group_name, str):
268
+ name = task_or_group_name
269
+ if isinstance(task_or_group_obj, Task):
270
+ # string_name = task_or_group_obj.task_name
271
+ name = task_or_group_obj.task_name
272
+ from_configurable_group = False
273
+
274
+ task_agg[name] = results[name].copy()
275
+ if from_configurable_group:
276
+ if task_or_group_name.group_alias is not None:
277
+ alias = task_or_group_name.group_alias
278
+ else:
279
+ alias = task_or_group_name.group
280
+ else:
281
+ if "alias" in task_agg[name]:
282
+ alias = task_agg[name]["alias"]
283
+ else:
284
+ alias = name
285
+
286
+ task_agg[name]["alias"] = tab_string + alias
287
+ if "samples" in task_agg[name]:
288
+ task_agg[name].pop("samples")
289
+
290
+ if from_configurable_group and (" " not in results[name]):
291
+ group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
292
+ group_agg[name] = results[name].copy()
293
+ group_agg[name]["alias"] = group_tab_string + alias
294
+ if "samples" in group_agg[name]:
295
+ group_agg[name].pop("samples")
296
+
297
+ if isinstance(task_or_group_obj, dict):
298
+ task_depth += 1
299
+ group_depth += 1
300
+ _task_agg, _group_agg = prepare_print_tasks(
301
+ task_or_group_obj, results, task_depth, group_depth
302
+ )
303
+ task_agg = {
304
+ **task_agg,
305
+ **_task_agg,
306
+ }
307
+ group_agg = {**group_agg, **_group_agg}
308
+ task_depth -= 1
309
+ group_depth -= 1
310
+ return task_agg, group_agg
311
+
312
+
313
+ def consolidate_results(
314
+ eval_tasks: List[TaskOutput],
315
+ ) -> Tuple[dict, dict, dict, dict, dict, dict]:
316
+ """
317
+ @param eval_tasks: list(TaskOutput).
318
+ @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
319
+
320
+ Consolidates the results of multiple evaluation tasks into a single structure.
321
+
322
+ The method iterates over each evaluation instance and extracts relevant information to create the consolidated
323
+ results structure. The consolidated results structure has the following properties:
324
+
325
+ - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
326
+ metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
327
+ aliases specified in the task configuration.
328
+ - samples: A defaultdict with task names as keys and lists of log samples as values.
329
+ - configs: A defaultdict with task names as keys and task configurations as values.
330
+ - versions: A defaultdict with task names as keys and task versions as values.
331
+ - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
332
+ - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
333
+ for each metric as values.
334
+
335
+ The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
336
+ """
337
+ # stores the final result for each task, for each metric/filter pair.
338
+ results = collections.defaultdict(dict)
339
+ # logs info about each document evaluated.
340
+ samples = collections.defaultdict(list)
341
+ # store num-fewshot value per task
342
+ num_fewshot = collections.defaultdict(int)
343
+ # Tracks the YAML configs of all chosen task
344
+ configs = collections.defaultdict(dict)
345
+ # Tracks each task's version.
346
+ versions = collections.defaultdict(dict)
347
+ # Track `higher_is_better` for each metric
348
+ higher_is_better = collections.defaultdict(dict)
349
+
350
+ for task_output in eval_tasks:
351
+ if "task_alias" in (task_config := task_output.task_config):
352
+ results[task_output.task_name]["alias"] = task_config["task_alias"]
353
+ else:
354
+ results[task_output.task_name]["alias"] = task_output.task_name
355
+ if group_alias := task_output.group_alias:
356
+ if group_alias not in results and (group_name := task_output.group_name):
357
+ results[group_name]["alias"] = group_alias
358
+ num_fewshot[task_output.task_name] = task_output.n_shot
359
+ configs[task_output.task_name] = task_output.task_config
360
+ versions[task_output.task_name] = task_output.version
361
+ samples[task_output.task_name] = task_output.logged_samples
362
+ higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
363
+ for (metric, filter_key), items in task_output.sample_metrics.items():
364
+ metric_key = f"{metric},{filter_key}"
365
+ results[task_output.task_name][metric_key] = task_output.agg_metrics[
366
+ metric_key
367
+ ]
368
+ results[task_output.task_name]["samples"] = task_output.sample_len
369
+ results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
370
+ task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
371
+ )
372
+ return results, samples, configs, versions, num_fewshot, higher_is_better
373
+
374
+
375
+ def consolidate_group_results(
376
+ results,
377
+ versions,
378
+ task_dict,
379
+ task_root=None,
380
+ show_group_table=False,
381
+ task_aggregation_list=None,
382
+ ) -> Tuple[dict, dict, bool, Union[None,]]:
383
+ """
384
+ (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
385
+
386
+ @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
387
+
388
+ - results: A defaultdict with task names (and, after this function is called, group names of
389
+ groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
390
+ - versions: A defaultdict with task names (and, after this function is called, group names of
391
+ groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
392
+ - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
393
+ - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
394
+
395
+ The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
396
+ In the top-level invocation of this function, task_aggregation_list is ignored.
397
+ """
398
+ if task_root is None:
399
+ task_root = {}
400
+
401
+ if task_aggregation_list is None:
402
+ task_aggregation_list = {}
403
+
404
+ for group_or_task, group_or_task_info in task_dict.items():
405
+ # Convert to string
406
+ if isinstance(group_or_task, ConfigurableGroup):
407
+ group_config = group_or_task.config
408
+ group_or_task = group_or_task.group_name
409
+ else:
410
+ group_config = None
411
+
412
+ if isinstance(group_or_task_info, Task):
413
+ if task_root:
414
+ task_aggregation_list.setdefault(task_root, []).append(
415
+ group_or_task_info.task_name
416
+ )
417
+ else:
418
+ (
419
+ results,
420
+ versions,
421
+ show_group_table,
422
+ _task_aggregation_list,
423
+ ) = consolidate_group_results(
424
+ results,
425
+ versions,
426
+ group_or_task_info,
427
+ group_or_task,
428
+ show_group_table,
429
+ task_aggregation_list,
430
+ )
431
+ if task_root:
432
+ task_aggregation_list.setdefault(task_root, []).extend(
433
+ task_aggregation_list.get(group_or_task, [])
434
+ )
435
+
436
+ if (group_config is None) or (
437
+ group_config["aggregate_metric_list"] is None
438
+ ):
439
+ results[group_or_task][" "] = " "
440
+ continue
441
+
442
+ if "aggregate_metric_list" in group_config:
443
+ agg_metric_list = group_config["aggregate_metric_list"]
444
+
445
+ show_group_table = show_group_table | bool(
446
+ group_config["aggregate_metric_list"]
447
+ )
448
+
449
+ task_list = _task_aggregation_list[group_or_task]
450
+
451
+ metric_list = list(
452
+ {
453
+ key
454
+ for task in task_list
455
+ for key in results[task].keys()
456
+ if "_stderr" not in key and key not in ["task", "alias", "samples"]
457
+ }
458
+ )
459
+ for metric in metric_list:
460
+ stderr = "_stderr,".join(metric.split(","))
461
+
462
+ # gather metrics, sizes, and stderrs from subtasks
463
+ metrics = [
464
+ results[task][metric]
465
+ for task in task_list
466
+ if metric in results[task]
467
+ ] # TODO: copy?
468
+ stderrs = [
469
+ results[task][stderr]
470
+ for task in task_list
471
+ if stderr in results[task]
472
+ ]
473
+ sizes = [
474
+ results[task]["samples"]
475
+ for task in task_list
476
+ if metric in results[task]
477
+ ]
478
+
479
+ for metric_config in agg_metric_list:
480
+ for filter_name in metric_config["filter_list"]:
481
+ if metric != ",".join([metric_config["metric"], filter_name]):
482
+ continue
483
+
484
+ # compute group's pooled metric and stderr
485
+ if metric_config["aggregation"] == "mean":
486
+ aggregate_fn = aggregate_subtask_metrics
487
+ elif callable(metric_config["aggregation"]):
488
+ aggregate_fn = metric_config["aggregation"]
489
+ else:
490
+ raise ValueError(
491
+ f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
492
+ )
493
+
494
+ results[group_or_task][metric] = aggregate_fn(
495
+ metrics,
496
+ sizes,
497
+ metric_config["weight_by_size"],
498
+ )
499
+ # TODO: calculate groups' metrics using arbitrary agg fns
500
+ if "N/A" in stderrs:
501
+ results[group_or_task][stderr] = "N/A"
502
+ else:
503
+ # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
504
+ results[group_or_task][stderr] = pooled_sample_stderr(
505
+ stderrs, sizes
506
+ )
507
+
508
+ results[group_or_task]["samples"] = sum(sizes)
509
+ group_metadata = group_config.get("metadata", None)
510
+ if group_metadata is not None:
511
+ versions[group_or_task] = group_metadata.get("version", None)
512
+ # print(results)
513
+ return results, versions, show_group_table, task_aggregation_list
514
+
515
+
516
+ @positional_deprecated
517
+ def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
518
+ """
519
+ Search upward in the directory tree to a maximum of three layers
520
+ to find and return the package root (containing the 'tests' folder)
521
+ """
522
+ cur_path = start_path.resolve()
523
+ max_layers = 3
524
+ for _ in range(max_layers):
525
+ if (cur_path / "tests" / "test_version_stable.py").exists():
526
+ return cur_path
527
+ else:
528
+ cur_path = cur_path.parent.resolve()
529
+ raise FileNotFoundError(
530
+ f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
531
+ )
532
+
533
+
534
+ @positional_deprecated
535
+ def run_task_tests(task_list: List[str]):
536
+ """
537
+ Find the package root and run the tests for the given tasks
538
+ """
539
+ import pytest
540
+
541
+ package_root = find_test_root(start_path=pathlib.Path(__file__))
542
+ task_string = " or ".join(task_list)
543
+ args = [
544
+ f"{package_root}/tests/test_version_stable.py",
545
+ f"--rootdir={package_root}",
546
+ "-k",
547
+ f"{task_string}",
548
+ ]
549
+ sys.path.append(str(package_root))
550
+ pytest_return_val = pytest.main(args)
551
+ if pytest_return_val:
552
+ raise ValueError(
553
+ f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
554
+ )
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import List
3
+
4
+ from lm_eval.api.filter import FilterEnsemble
5
+ from lm_eval.api.registry import get_filter
6
+
7
+ from . import custom, extraction, selection, transformation
8
+
9
+
10
+ def build_filter_ensemble(
11
+ filter_name: str, components: List[List[str]]
12
+ ) -> FilterEnsemble:
13
+ """
14
+ Create a filtering pipeline.
15
+ """
16
+ filters = []
17
+ for function, kwargs in components:
18
+ if kwargs is None:
19
+ kwargs = {}
20
+ # create a filter given its name in the registry
21
+ f = partial(get_filter(function), **kwargs)
22
+ # add the filter as a pipeline step
23
+ filters.append(f)
24
+
25
+ return FilterEnsemble(name=filter_name, filters=filters)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/custom.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lm_eval.api.filter import Filter
2
+ from lm_eval.api.registry import register_filter
3
+
4
+
5
+ @register_filter("custom")
6
+ class CustomFilter(Filter):
7
+ """
8
+ Custom filter that applies a custom, user-defined function to the model responses.
9
+ """
10
+
11
+ def __init__(self, **kwargs) -> None:
12
+ self.filter_fn = kwargs.pop("filter_fn")
13
+
14
+ super().__init__(**kwargs)
15
+
16
+ def apply(self, resps, docs):
17
+ return self.filter_fn(resps, docs)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/decontamination.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lm_eval.api.filter import Filter
2
+ from lm_eval.api.registry import register_filter
3
+
4
+
5
+ @register_filter("decontaminate")
6
+ class DecontaminationFilter(Filter):
7
+ """
8
+ A filter which evaluates
9
+ """
10
+
11
+ name = "track_decontamination"
12
+
13
+ def __init__(self, path) -> None:
14
+ """
15
+
16
+ TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
17
+ should further cache result on a given (task_name, doc_id)
18
+ """
19
+ self._decontam_results = None
20
+
21
+ def apply(self, resps, docs) -> None:
22
+ """
23
+ Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
24
+ """
25
+ pass
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/extraction.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import sys
3
+ import unicodedata
4
+
5
+ from lm_eval.api.filter import Filter
6
+ from lm_eval.api.registry import register_filter
7
+
8
+
9
+ @register_filter("regex")
10
+ class RegexFilter(Filter):
11
+ """A filter that extracts values from text using regex pattern matching.
12
+
13
+ This filter applies a regex pattern to each model response and extracts matched values.
14
+ If no match is found, returns a fallback value. Useful for extracting structured data
15
+ (like numbers) from unstructured model outputs.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
21
+ group_select: int = 0,
22
+ fallback: str = "[invalid]",
23
+ ) -> None:
24
+ """
25
+ pass a string `regex` to run `re.compile(r"regex")` on.
26
+ `fallback` defines the output returned if no matches for the regex are located.
27
+ """
28
+ self.regex_pattern = regex_pattern
29
+ self.regex = re.compile(regex_pattern)
30
+ self.group_select = group_select
31
+ self.fallback = fallback
32
+
33
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
34
+ # here, we assume we have a list, in which each element is
35
+ # a list of model responses for some particular input/target pair.
36
+ # so we process each of these (same input/target response sets)
37
+ # independently (and keep them a list.)
38
+ def filter_set(inst):
39
+ filtered = []
40
+ for resp in inst:
41
+ match = self.regex.findall(resp)
42
+ if match:
43
+ match = match[self.group_select]
44
+ if isinstance(match, tuple):
45
+ match = [m for m in match if m]
46
+ if match:
47
+ match = match[0]
48
+ else:
49
+ match = self.fallback
50
+ match = match.strip()
51
+ else:
52
+ match = self.fallback
53
+ filtered.append(match)
54
+ return filtered
55
+
56
+ filtered_resps = list(map(lambda x: filter_set(x), resps))
57
+
58
+ return filtered_resps
59
+
60
+
61
+ @register_filter("remove_whitespace")
62
+ class WhitespaceFilter(Filter):
63
+ """Filters out leading whitespace from responses."""
64
+
65
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
66
+ def filter_set(inst):
67
+ filtered_resp = []
68
+ for resp in inst:
69
+ resp = resp.lstrip()
70
+ filtered_resp.append(resp)
71
+ return filtered_resp
72
+
73
+ filtered_resps = [filter_set(resp) for resp in resps]
74
+
75
+ return filtered_resps
76
+
77
+
78
+ @register_filter("multi_choice_regex")
79
+ class MultiChoiceRegexFilter(RegexFilter):
80
+ """
81
+ A filter used to extract a model's answer on multiple choice questions with
82
+ letter answers. assumes each document has a "choices" field
83
+ containing the list of answer choices and that the answer label symbols
84
+ are of the form (A), (B), (C), ... or A, B, C.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
90
+ group_select=0,
91
+ fallback: str = "[invalid]",
92
+ ignore_case=False,
93
+ ignore_punctuation=False,
94
+ regexes_to_ignore=None,
95
+ ) -> None:
96
+ """
97
+ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
98
+ - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
99
+ - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
100
+ group_select: Selects the (group_select)th match from the findall result.
101
+ ignore_case: Ignores the case during step 1 matching
102
+ ignore_punctuation: Remove the punctuation during step 1 matching
103
+ regexes_to_ignore: Remove these regexes during step 1 matching
104
+ """
105
+ super().__init__(regex_pattern, group_select, fallback)
106
+ self.ignore_case = ignore_case
107
+ self.ignore_punctuation = ignore_punctuation
108
+ self.regexes_to_ignore = regexes_to_ignore
109
+
110
+ def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
111
+ # here, we assume we have a list, in which each element is
112
+ # a list of model responses for some particular input/target pair.
113
+ # so we process each of these (same input/target response sets)
114
+ # independently (and keep them a list.)
115
+
116
+ def find_match(regex, resp, convert_dict={}):
117
+ match = regex.findall(resp)
118
+ if match:
119
+ match = match[self.group_select]
120
+ if isinstance(match, tuple):
121
+ match = [m for m in match if m][0]
122
+ match = match.strip()
123
+ if match and match in convert_dict:
124
+ match = convert_dict[match]
125
+ return match
126
+
127
+ punct_tbl = dict.fromkeys(
128
+ i
129
+ for i in range(sys.maxunicode)
130
+ if unicodedata.category(chr(i)).startswith("P")
131
+ )
132
+
133
+ def filter_ignores(st):
134
+ if self.regexes_to_ignore is not None:
135
+ for s in self.regexes_to_ignore:
136
+ st = re.sub(s, "", st)
137
+
138
+ if self.ignore_case:
139
+ st = st.lower()
140
+
141
+ if self.ignore_punctuation:
142
+ # https://stackoverflow.com/a/266162
143
+ st = st.translate(punct_tbl)
144
+ return st
145
+
146
+ filtered_resps = []
147
+
148
+ for r, doc in zip(resps, docs):
149
+ fallback_regexes = []
150
+ choice_to_alpha = {}
151
+ next_alpha = "A"
152
+
153
+ without_paren_fallback_regexes = []
154
+ without_paren_to_target = {}
155
+
156
+ choices = doc["choices"]
157
+ for c in choices:
158
+ m = filter_ignores(c.strip())
159
+ fallback_regexes.append(f"{re.escape(m)}")
160
+ choice_to_alpha[m] = f"({next_alpha})"
161
+
162
+ without_paren_fallback_regexes.append(next_alpha)
163
+ without_paren_to_target[next_alpha] = f"({next_alpha})"
164
+
165
+ next_alpha = chr(ord(next_alpha) + 1)
166
+ fallback_regex = re.compile("|".join(fallback_regexes))
167
+ without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
168
+ without_paren_fallback_regex = re.compile(
169
+ rf":[\s]*({without_paren_fallback_regex})"
170
+ )
171
+
172
+ filtered = []
173
+ for resp in r:
174
+ match = find_match(self.regex, resp)
175
+ if not match:
176
+ match = find_match(
177
+ fallback_regex, filter_ignores(resp), choice_to_alpha
178
+ )
179
+ if not match:
180
+ match = find_match(
181
+ without_paren_fallback_regex, resp, without_paren_to_target
182
+ )
183
+ if not match:
184
+ match = self.fallback
185
+ filtered.append(match)
186
+ filtered_resps.append(filtered)
187
+
188
+ return filtered_resps
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/selection.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import Counter
2
+
3
+ from lm_eval.api.filter import Filter
4
+ from lm_eval.api.registry import register_filter
5
+
6
+
7
+ # TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
8
+ # that takes an input and returns a scalar and then should select the max reward,
9
+ # or should implement different filters for different ways of handling a reward model's inference.
10
+
11
+
12
+ @register_filter("take_first")
13
+ class TakeFirstFilter(Filter):
14
+ def __init__(self) -> None:
15
+ """
16
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
17
+ """
18
+
19
+ def apply(self, resps, docs):
20
+ """
21
+ Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
22
+ """
23
+ return map(lambda r: r[0], resps)
24
+
25
+
26
+ @register_filter("take_first_k")
27
+ class TakeKFilter(Filter):
28
+ def __init__(self, **kwargs) -> None:
29
+ self.k = kwargs.pop("k")
30
+
31
+ super().__init__(**kwargs)
32
+
33
+ def apply(self, resps, docs):
34
+ # need resp to be subscriptable to check below
35
+ resps = list(resps)
36
+ # check we have at least k responses per doc, else we can't take the first k
37
+ assert len(resps[0]) >= self.k, (
38
+ f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
39
+ )
40
+ return map(lambda r: r[: self.k], resps)
41
+
42
+
43
+ @register_filter("majority_vote")
44
+ class MajorityVoteFilter(Filter):
45
+ def __init__(self) -> None:
46
+ """
47
+ Can define custom behavior here, if an individual instantiation of a Filter class should have state.
48
+ """
49
+
50
+ def apply(self, resps, docs):
51
+ """
52
+ Each entry of `resps` is a list of model responses.
53
+ We select the response that occurs most frequently in each entry of `resps`.
54
+ """
55
+
56
+ def select_majority(resp):
57
+ counts = Counter(resp)
58
+ vote = counts.most_common(1)[0][0]
59
+ return vote
60
+
61
+ return map(lambda r: [select_majority(r)], resps)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/filters/transformation.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lm_eval.api.filter import Filter
2
+ from lm_eval.api.registry import register_filter
3
+
4
+
5
+ @register_filter("lowercase")
6
+ class LowercaseFilter(Filter):
7
+ def __init__(self) -> None:
8
+ pass
9
+
10
+ def apply(self, resps, docs):
11
+ def filter_set(inst):
12
+ return [resp.lower() for resp in inst]
13
+
14
+ return [filter_set(resp) for resp in resps]
15
+
16
+
17
+ @register_filter("uppercase")
18
+ class UppercaseFilter(Filter):
19
+ def __init__(self) -> None:
20
+ pass
21
+
22
+ def apply(self, resps, docs):
23
+ def filter_set(inst):
24
+ return [resp.upper() for resp in inst]
25
+
26
+ return [filter_set(resp) for resp in resps]
27
+
28
+
29
+ @register_filter("map")
30
+ class MapFilter(Filter):
31
+ def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
32
+ """
33
+ Initializes the MapFilter with a given mapping dictionary and default value.
34
+
35
+ Args:
36
+ - mapping_dict (dict): A dictionary containing the key-value mappings.
37
+ Default is an empty dictionary.
38
+ - default_value (Any): The value to be returned when a key is not found in the mapping_dict.
39
+ Default is None.
40
+
41
+ Example:
42
+ mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
43
+ """
44
+ if mapping_dict is None:
45
+ mapping_dict = {}
46
+ assert isinstance(mapping_dict, dict), (
47
+ "Provided mapping_dict is not a dictionary"
48
+ )
49
+ self.mapping_dict = mapping_dict
50
+ self.default_value = default_value
51
+
52
+ def apply(self, resps, docs):
53
+ def filter_set(inst):
54
+ return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
55
+
56
+ return [filter_set(resp) for resp in resps]
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .evaluation_tracker import EvaluationTracker
2
+ from .wandb_logger import WandbLogger
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/evaluation_tracker.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ import time
6
+ from collections import defaultdict
7
+ from dataclasses import asdict, dataclass
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ from datasets import load_dataset
12
+ from datasets.utils.metadata import MetadataConfigs
13
+ from huggingface_hub import (
14
+ DatasetCard,
15
+ DatasetCardData,
16
+ HfApi,
17
+ hf_hub_url,
18
+ )
19
+ from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
20
+
21
+ from lm_eval.utils import (
22
+ get_file_datetime,
23
+ get_file_task_name,
24
+ get_results_filenames,
25
+ get_sample_results_filenames,
26
+ handle_non_serializable,
27
+ hash_string,
28
+ sanitize_list,
29
+ sanitize_model_name,
30
+ sanitize_task_name,
31
+ )
32
+
33
+
34
+ eval_logger = logging.getLogger(__name__)
35
+
36
+
37
+ @dataclass(init=False)
38
+ class GeneralConfigTracker:
39
+ """
40
+ Tracker for the evaluation parameters.
41
+
42
+ Attributes:
43
+ model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
44
+ model_name (str): Name of the model.
45
+ model_name_sanitized (str): Sanitized model name for directory creation.
46
+ start_time (float): Start time of the experiment. Logged at class init.
47
+ end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
48
+ total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
49
+ """
50
+
51
+ model_source: str = None
52
+ model_name: str = None
53
+ model_name_sanitized: str = None
54
+ system_instruction: str = None
55
+ system_instruction_sha: str = None
56
+ fewshot_as_multiturn: bool = None
57
+ chat_template: str = None
58
+ chat_template_sha: str = None
59
+ start_time: float = None
60
+ end_time: float = None
61
+ total_evaluation_time_seconds: str = None
62
+
63
+ def __init__(self) -> None:
64
+ """Starts the evaluation timer."""
65
+ self.start_time = time.perf_counter()
66
+
67
+ @staticmethod
68
+ def _get_model_name(model_args: str) -> str:
69
+ """Extracts the model name from the model arguments."""
70
+
71
+ def extract_model_name(model_args: str, key: str) -> str:
72
+ """Extracts the model name from the model arguments using a key."""
73
+ args_after_key = model_args.split(key)[1]
74
+ return args_after_key.split(",")[0]
75
+
76
+ # order does matter, e.g. peft and delta are provided together with pretrained
77
+ prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="]
78
+ for prefix in prefixes:
79
+ if prefix in model_args:
80
+ return extract_model_name(model_args, prefix)
81
+ return ""
82
+
83
+ def log_experiment_args(
84
+ self,
85
+ model_source: str,
86
+ model_args: str,
87
+ system_instruction: str,
88
+ chat_template: str,
89
+ fewshot_as_multiturn: bool,
90
+ ) -> None:
91
+ """Logs model parameters and job ID."""
92
+ self.model_source = model_source
93
+ self.model_name = GeneralConfigTracker._get_model_name(model_args)
94
+ self.model_name_sanitized = sanitize_model_name(self.model_name)
95
+ self.system_instruction = system_instruction
96
+ self.system_instruction_sha = (
97
+ hash_string(system_instruction) if system_instruction else None
98
+ )
99
+ self.chat_template = chat_template
100
+ self.chat_template_sha = hash_string(chat_template) if chat_template else None
101
+ self.fewshot_as_multiturn = fewshot_as_multiturn
102
+
103
+ def log_end_time(self) -> None:
104
+ """Logs the end time of the evaluation and calculates the total evaluation time."""
105
+ self.end_time = time.perf_counter()
106
+ self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
107
+
108
+
109
+ class EvaluationTracker:
110
+ """
111
+ Keeps track and saves relevant information of the evaluation process.
112
+ Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ output_path: str = None,
118
+ hub_results_org: str = "",
119
+ hub_repo_name: str = "",
120
+ details_repo_name: str = "",
121
+ results_repo_name: str = "",
122
+ push_results_to_hub: bool = False,
123
+ push_samples_to_hub: bool = False,
124
+ public_repo: bool = False,
125
+ token: str = "",
126
+ leaderboard_url: str = "",
127
+ point_of_contact: str = "",
128
+ gated: bool = False,
129
+ ) -> None:
130
+ """
131
+ Creates all the necessary loggers for evaluation tracking.
132
+
133
+ Args:
134
+ output_path (str): Path to save the results. If not provided, the results won't be saved.
135
+ hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
136
+ hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
137
+ details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
138
+ result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
139
+ push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
140
+ push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
141
+ public_repo (bool): Whether to push the results to a public or private repository.
142
+ token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
143
+ leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
144
+ point_of_contact (str): Contact information on the Hugging Face hub dataset card.
145
+ gated (bool): Whether to gate the repository.
146
+ """
147
+ self.general_config_tracker = GeneralConfigTracker()
148
+
149
+ self.output_path = output_path
150
+ self.push_results_to_hub = push_results_to_hub
151
+ self.push_samples_to_hub = push_samples_to_hub
152
+ self.public_repo = public_repo
153
+ self.leaderboard_url = leaderboard_url
154
+ self.point_of_contact = point_of_contact
155
+ self.api = HfApi(token=token) if token else None
156
+ self.gated_repo = gated
157
+
158
+ if not self.api and (push_results_to_hub or push_samples_to_hub):
159
+ raise ValueError(
160
+ "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
161
+ "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
162
+ )
163
+
164
+ if (
165
+ self.api
166
+ and hub_results_org == ""
167
+ and (push_results_to_hub or push_samples_to_hub)
168
+ ):
169
+ hub_results_org = self.api.whoami()["name"]
170
+ eval_logger.warning(
171
+ f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
172
+ )
173
+
174
+ if hub_repo_name == "":
175
+ details_repo_name = (
176
+ details_repo_name if details_repo_name != "" else "lm-eval-results"
177
+ )
178
+ results_repo_name = (
179
+ results_repo_name if results_repo_name != "" else details_repo_name
180
+ )
181
+ else:
182
+ details_repo_name = hub_repo_name
183
+ results_repo_name = hub_repo_name
184
+ eval_logger.warning(
185
+ "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
186
+ )
187
+
188
+ self.details_repo = f"{hub_results_org}/{details_repo_name}"
189
+ self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
190
+ self.results_repo = f"{hub_results_org}/{results_repo_name}"
191
+ self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
192
+
193
+ def save_results_aggregated(
194
+ self,
195
+ results: dict,
196
+ samples: dict,
197
+ ) -> None:
198
+ """
199
+ Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.
200
+
201
+ Args:
202
+ results (dict): The aggregated results to save.
203
+ samples (dict): The samples results to save.
204
+ """
205
+ self.general_config_tracker.log_end_time()
206
+
207
+ if self.output_path:
208
+ try:
209
+ eval_logger.info("Saving results aggregated")
210
+
211
+ # calculate cumulative hash for each task - only if samples are provided
212
+ task_hashes = {}
213
+ if samples:
214
+ for task_name, task_samples in samples.items():
215
+ sample_hashes = [
216
+ s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
217
+ for s in task_samples
218
+ ]
219
+ task_hashes[task_name] = hash_string("".join(sample_hashes))
220
+
221
+ # update initial results dict
222
+ results.update({"task_hashes": task_hashes})
223
+ results.update(asdict(self.general_config_tracker))
224
+ dumped = json.dumps(
225
+ results,
226
+ indent=2,
227
+ default=handle_non_serializable,
228
+ ensure_ascii=False,
229
+ )
230
+
231
+ path = Path(self.output_path if self.output_path else Path.cwd())
232
+ path = path.joinpath(self.general_config_tracker.model_name_sanitized)
233
+ path.mkdir(parents=True, exist_ok=True)
234
+
235
+ self.date_id = datetime.now().isoformat().replace(":", "-")
236
+ file_results_aggregated = path.joinpath(f"results_{self.date_id}.json")
237
+ file_results_aggregated.open("w", encoding="utf-8").write(dumped)
238
+
239
+ if self.api and self.push_results_to_hub:
240
+ repo_id = (
241
+ self.results_repo
242
+ if self.public_repo
243
+ else self.results_repo_private
244
+ )
245
+ self.api.create_repo(
246
+ repo_id=repo_id,
247
+ repo_type="dataset",
248
+ private=not self.public_repo,
249
+ exist_ok=True,
250
+ )
251
+ self.api.upload_file(
252
+ repo_id=repo_id,
253
+ path_or_fileobj=str(
254
+ path.joinpath(f"results_{self.date_id}.json")
255
+ ),
256
+ path_in_repo=os.path.join(
257
+ self.general_config_tracker.model_name,
258
+ f"results_{self.date_id}.json",
259
+ ),
260
+ repo_type="dataset",
261
+ commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
262
+ )
263
+ eval_logger.info(
264
+ "Successfully pushed aggregated results to the Hugging Face Hub. "
265
+ f"You can find them at: {repo_id}"
266
+ )
267
+
268
+ except Exception as e:
269
+ eval_logger.warning("Could not save results aggregated")
270
+ eval_logger.info(repr(e))
271
+ else:
272
+ eval_logger.info(
273
+ "Output path not provided, skipping saving results aggregated"
274
+ )
275
+
276
+ def save_results_samples(
277
+ self,
278
+ task_name: str,
279
+ samples: dict,
280
+ ) -> None:
281
+ """
282
+ Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.
283
+
284
+ Args:
285
+ task_name (str): The task name to save the samples for.
286
+ samples (dict): The samples results to save.
287
+ """
288
+ if self.output_path:
289
+ try:
290
+ eval_logger.info(f"Saving per-sample results for: {task_name}")
291
+
292
+ path = Path(self.output_path if self.output_path else Path.cwd())
293
+ path = path.joinpath(self.general_config_tracker.model_name_sanitized)
294
+ path.mkdir(parents=True, exist_ok=True)
295
+
296
+ file_results_samples = path.joinpath(
297
+ f"samples_{task_name}_{self.date_id}.jsonl"
298
+ )
299
+
300
+ for sample in samples:
301
+ # we first need to sanitize arguments and resps
302
+ # otherwise we won't be able to load the dataset
303
+ # using the datasets library
304
+ arguments = {}
305
+ for i, arg in enumerate(sample["arguments"]):
306
+ arguments[f"gen_args_{i}"] = {}
307
+ for j, tmp in enumerate(arg):
308
+ arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
309
+
310
+ sample["resps"] = sanitize_list(sample["resps"])
311
+ sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
312
+ sample["arguments"] = arguments
313
+ sample["target"] = str(sample["target"])
314
+
315
+ sample_dump = (
316
+ json.dumps(
317
+ sample,
318
+ default=handle_non_serializable,
319
+ ensure_ascii=False,
320
+ )
321
+ + "\n"
322
+ )
323
+
324
+ with open(file_results_samples, "a", encoding="utf-8") as f:
325
+ f.write(sample_dump)
326
+
327
+ if self.api and self.push_samples_to_hub:
328
+ repo_id = (
329
+ self.details_repo
330
+ if self.public_repo
331
+ else self.details_repo_private
332
+ )
333
+ self.api.create_repo(
334
+ repo_id=repo_id,
335
+ repo_type="dataset",
336
+ private=not self.public_repo,
337
+ exist_ok=True,
338
+ )
339
+ try:
340
+ if self.gated_repo:
341
+ headers = build_hf_headers()
342
+ r = get_session().put(
343
+ url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
344
+ headers=headers,
345
+ json={"gated": "auto"},
346
+ )
347
+ hf_raise_for_status(r)
348
+ except Exception as e:
349
+ eval_logger.warning("Could not gate the repository")
350
+ eval_logger.info(repr(e))
351
+ self.api.upload_folder(
352
+ repo_id=repo_id,
353
+ folder_path=str(path),
354
+ path_in_repo=self.general_config_tracker.model_name_sanitized,
355
+ repo_type="dataset",
356
+ commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
357
+ )
358
+ eval_logger.info(
359
+ f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
360
+ f"You can find them at: {repo_id}"
361
+ )
362
+
363
+ except Exception as e:
364
+ eval_logger.warning("Could not save sample results")
365
+ eval_logger.info(repr(e))
366
+ else:
367
+ eval_logger.info("Output path not provided, skipping saving sample results")
368
+
369
+ def recreate_metadata_card(self) -> None:
370
+ """
371
+ Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
372
+ """
373
+
374
+ eval_logger.info("Recreating metadata card")
375
+ repo_id = self.details_repo if self.public_repo else self.details_repo_private
376
+
377
+ files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
378
+ results_files = get_results_filenames(files_in_repo)
379
+ sample_files = get_sample_results_filenames(files_in_repo)
380
+
381
+ # Build a dictionary to store the latest evaluation datetime for:
382
+ # - Each tested model and its aggregated results
383
+ # - Each task and sample results, if existing
384
+ # i.e. {
385
+ # "org__model_name__gsm8k": "2021-09-01T12:00:00",
386
+ # "org__model_name__ifeval": "2021-09-01T12:00:00",
387
+ # "org__model_name__results": "2021-09-01T12:00:00"
388
+ # }
389
+ latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())
390
+
391
+ for file_path in sample_files:
392
+ file_path = Path(file_path)
393
+ filename = file_path.name
394
+ model_name = file_path.parent
395
+ task_name = get_file_task_name(filename)
396
+ results_datetime = get_file_datetime(filename)
397
+ task_name_sanitized = sanitize_task_name(task_name)
398
+ # Results and sample results for the same model and task will have the same datetime
399
+ samples_key = f"{model_name}__{task_name_sanitized}"
400
+ results_key = f"{model_name}__results"
401
+ latest_datetime = max(
402
+ latest_task_results_datetime[samples_key],
403
+ results_datetime,
404
+ )
405
+ latest_task_results_datetime[samples_key] = latest_datetime
406
+ latest_task_results_datetime[results_key] = max(
407
+ latest_task_results_datetime[results_key],
408
+ latest_datetime,
409
+ )
410
+
411
+ # Create metadata card
412
+ card_metadata = MetadataConfigs()
413
+
414
+ # Add the latest aggregated results to the metadata card for easy access
415
+ for file_path in results_files:
416
+ file_path = Path(file_path)
417
+ results_filename = file_path.name
418
+ model_name = file_path.parent
419
+ eval_date = get_file_datetime(results_filename)
420
+ eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
421
+ results_filename = Path("**") / Path(results_filename).name
422
+ config_name = f"{model_name}__results"
423
+ sanitized_last_eval_date_results = re.sub(
424
+ r"[^\w\.]", "_", latest_task_results_datetime[config_name]
425
+ )
426
+
427
+ if eval_date_sanitized == sanitized_last_eval_date_results:
428
+ # Ensure that all results files are listed in the metadata card
429
+ current_results = card_metadata.get(config_name, {"data_files": []})
430
+ current_results["data_files"].append(
431
+ {"split": eval_date_sanitized, "path": [str(results_filename)]}
432
+ )
433
+ card_metadata[config_name] = current_results
434
+ # If the results file is the newest, update the "latest" field in the metadata card
435
+ card_metadata[config_name]["data_files"].append(
436
+ {"split": "latest", "path": [str(results_filename)]}
437
+ )
438
+
439
+ # Add the tasks details configs
440
+ for file_path in sample_files:
441
+ file_path = Path(file_path)
442
+ filename = file_path.name
443
+ model_name = file_path.parent
444
+ task_name = get_file_task_name(filename)
445
+ eval_date = get_file_datetime(filename)
446
+ task_name_sanitized = sanitize_task_name(task_name)
447
+ eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
448
+ results_filename = Path("**") / Path(filename).name
449
+ config_name = f"{model_name}__{task_name_sanitized}"
450
+ sanitized_last_eval_date_results = re.sub(
451
+ r"[^\w\.]", "_", latest_task_results_datetime[config_name]
452
+ )
453
+ if eval_date_sanitized == sanitized_last_eval_date_results:
454
+ # Ensure that all sample results files are listed in the metadata card
455
+ current_details_for_task = card_metadata.get(
456
+ config_name, {"data_files": []}
457
+ )
458
+ current_details_for_task["data_files"].append(
459
+ {"split": eval_date_sanitized, "path": [str(results_filename)]}
460
+ )
461
+ card_metadata[config_name] = current_details_for_task
462
+ # If the samples results file is the newest, update the "latest" field in the metadata card
463
+ card_metadata[config_name]["data_files"].append(
464
+ {"split": "latest", "path": [str(results_filename)]}
465
+ )
466
+
467
+ # Get latest results and extract info to update metadata card examples
468
+ latest_datetime = max(latest_task_results_datetime.values())
469
+ latest_model_name = max(
470
+ latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
471
+ )
472
+ last_results_file = [
473
+ f for f in results_files if latest_datetime.replace(":", "-") in f
474
+ ][0]
475
+ last_results_file_path = hf_hub_url(
476
+ repo_id=repo_id, filename=last_results_file, repo_type="dataset"
477
+ )
478
+ latest_results_file = load_dataset(
479
+ "json", data_files=last_results_file_path, split="train"
480
+ )
481
+ results_dict = latest_results_file["results"][0]
482
+ new_dictionary = {"all": results_dict}
483
+ new_dictionary.update(results_dict)
484
+ results_string = json.dumps(new_dictionary, indent=4)
485
+
486
+ dataset_summary = (
487
+ "Dataset automatically created during the evaluation run of model "
488
+ )
489
+ if self.general_config_tracker.model_source == "hf":
490
+ dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
491
+ else:
492
+ dataset_summary += f"{self.general_config_tracker.model_name}\n"
493
+ dataset_summary += (
494
+ f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
495
+ f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
496
+ 'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
497
+ 'An additional configuration "results" store all the aggregated results of the run.\n\n'
498
+ "To load the details from a run, you can for instance do the following:\n"
499
+ )
500
+ if self.general_config_tracker.model_source == "hf":
501
+ dataset_summary += (
502
+ "```python\nfrom datasets import load_dataset\n"
503
+ f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
504
+ )
505
+ dataset_summary += (
506
+ "## Latest results\n\n"
507
+ f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
508
+ "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
509
+ 'You find each in the results and the "latest" split for each eval):\n\n'
510
+ f"```python\n{results_string}\n```"
511
+ )
512
+ card_data = DatasetCardData(
513
+ dataset_summary=dataset_summary,
514
+ repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
515
+ pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
516
+ leaderboard_url=self.leaderboard_url,
517
+ point_of_contact=self.point_of_contact,
518
+ )
519
+ card_metadata.to_dataset_card_data(card_data)
520
+ card = DatasetCard.from_template(
521
+ card_data,
522
+ pretty_name=card_data.pretty_name,
523
+ )
524
+ card.push_to_hub(repo_id, repo_type="dataset")
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/utils.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ import subprocess
5
+ from importlib.metadata import version
6
+ from pathlib import Path
7
+ from typing import Any, Dict, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ from torch.utils.collect_env import get_pretty_env_info
11
+ from transformers import __version__ as trans_version
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
18
+ """Remove the ',none' substring from the input_string if it exists at the end.
19
+
20
+ Args:
21
+ input_string (str): The input string from which to remove the ',none' substring.
22
+
23
+ Returns:
24
+ Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
25
+ and a boolean indicating whether the modification was made (True) or not (False).
26
+ """
27
+ # Define the pattern to match ',none' at the end of the string
28
+ pattern = re.compile(r",none$")
29
+
30
+ # Use sub() to replace ',none' with an empty string
31
+ result = re.sub(pattern, "", input_string)
32
+
33
+ # check if the input_string changed
34
+ removed = result != input_string
35
+
36
+ return result, removed
37
+
38
+
39
+ def _handle_non_serializable(o: Any) -> Union[int, str, list]:
40
+ """Handle non-serializable objects by converting them to serializable types.
41
+
42
+ Args:
43
+ o (Any): The object to be handled.
44
+
45
+ Returns:
46
+ Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
47
+ it will be converted to int. If the object is of type set, it will be converted
48
+ to a list. Otherwise, it will be converted to str.
49
+ """
50
+ if isinstance(o, np.int64) or isinstance(o, np.int32):
51
+ return int(o)
52
+ elif isinstance(o, set):
53
+ return list(o)
54
+ else:
55
+ return str(o)
56
+
57
+
58
+ def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
59
+ try:
60
+ git_folder = Path(repo_path, ".git")
61
+ if git_folder.is_file():
62
+ git_folder = Path(
63
+ git_folder.parent,
64
+ git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
65
+ )
66
+ if Path(git_folder, "HEAD").exists():
67
+ head_name = (
68
+ Path(git_folder, "HEAD")
69
+ .read_text(encoding="utf-8")
70
+ .split("\n")[0]
71
+ .split(" ")[-1]
72
+ )
73
+ head_ref = Path(git_folder, head_name)
74
+ git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
75
+ else:
76
+ git_hash = None
77
+ except Exception as err:
78
+ logger.debug(
79
+ f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
80
+ )
81
+ return None
82
+ return git_hash
83
+
84
+
85
+ def get_git_commit_hash():
86
+ """
87
+ Gets the git commit hash of your current repo (if it exists).
88
+ Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
89
+ """
90
+ try:
91
+ git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
92
+ git_hash = git_hash.decode()
93
+ except (subprocess.CalledProcessError, FileNotFoundError):
94
+ # FileNotFoundError occurs when git not installed on system
95
+ git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
96
+ return git_hash
97
+
98
+
99
+ def add_env_info(storage: Dict[str, Any]):
100
+ try:
101
+ pretty_env_info = get_pretty_env_info()
102
+ except Exception as err:
103
+ pretty_env_info = str(err)
104
+ try:
105
+ lm_eval_version = version("lm_eval")
106
+ except Exception as err:
107
+ lm_eval_version = str(err)
108
+ transformers_version = trans_version
109
+ upper_dir_commit = get_commit_from_path(
110
+ Path(os.getcwd(), "..")
111
+ ) # git hash of upper repo if exists
112
+ added_info = {
113
+ "pretty_env_info": pretty_env_info,
114
+ "transformers_version": transformers_version,
115
+ "lm_eval_version": lm_eval_version,
116
+ "upper_git_hash": upper_dir_commit, # in case this repo is submodule
117
+ }
118
+ storage.update(added_info)
119
+
120
+
121
+ def add_tokenizer_info(storage: Dict[str, Any], lm):
122
+ if getattr(lm, "tokenizer", False):
123
+ try:
124
+ tokenizer_info = {
125
+ "tokenizer_pad_token": [
126
+ lm.tokenizer.pad_token,
127
+ str(lm.tokenizer.pad_token_id),
128
+ ],
129
+ "tokenizer_eos_token": [
130
+ lm.tokenizer.eos_token,
131
+ str(lm.tokenizer.eos_token_id),
132
+ ],
133
+ "tokenizer_bos_token": [
134
+ lm.tokenizer.bos_token,
135
+ str(lm.tokenizer.bos_token_id),
136
+ ],
137
+ "eot_token_id": getattr(lm, "eot_token_id", None),
138
+ "max_length": getattr(lm, "max_length", None),
139
+ }
140
+ storage.update(tokenizer_info)
141
+ except Exception as err:
142
+ logger.debug(
143
+ f"Logging detailed tokenizer info failed with {err}, skipping..."
144
+ )
145
+ # seems gguf and textsynth do not have tokenizer
146
+ else:
147
+ logger.debug(
148
+ "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
149
+ )
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/loggers/wandb_logger.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ from typing import Any, Dict, List, Literal, Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from packaging.version import Version
9
+
10
+ from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ def get_wandb_printer() -> Literal["Printer"]:
17
+ """Returns a wandb printer instance for pretty stdout."""
18
+ from wandb.sdk.lib.printer import new_printer
19
+
20
+ printer = new_printer()
21
+ return printer
22
+
23
+
24
+ class WandbLogger:
25
+ def __init__(self, init_args=None, config_args=None) -> None:
26
+ """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()
27
+
28
+ Args:
29
+ init_args Optional[Dict]: Arguments for init configuration.
30
+ config_args Optional[Dict]: Arguments for config
31
+
32
+ Parse and log the results returned from evaluator.simple_evaluate() with:
33
+ wandb_logger.post_init(results)
34
+ wandb_logger.log_eval_result()
35
+ wandb_logger.log_eval_samples(results["samples"])
36
+ """
37
+ try:
38
+ import wandb
39
+
40
+ assert Version(wandb.__version__) >= Version("0.13.6")
41
+ if Version(wandb.__version__) < Version("0.13.6"):
42
+ wandb.require("report-editing:v0")
43
+ except Exception as e:
44
+ logger.warning(
45
+ "To use the wandb reporting functionality please install wandb>=0.13.6.\n"
46
+ "To install the latest version of wandb run `pip install wandb --upgrade`\n"
47
+ f"{e}"
48
+ )
49
+
50
+ self.wandb_args: Dict[str, Any] = init_args or {}
51
+ self.wandb_config_args: Dict[str, Any] = config_args or {}
52
+
53
+ # pop the step key from the args to save for all logging calls
54
+ self.step = self.wandb_args.pop("step", None)
55
+
56
+ # initialize a W&B run
57
+ if wandb.run is None:
58
+ self.run = wandb.init(**self.wandb_args)
59
+ if self.wandb_config_args:
60
+ self.run.config.update(self.wandb_config_args)
61
+ else:
62
+ self.run = wandb.run
63
+
64
+ self.printer = get_wandb_printer()
65
+
66
+ def post_init(self, results: Dict[str, Any]) -> None:
67
+ self.results: Dict[str, Any] = copy.deepcopy(results)
68
+ self.task_names: List[str] = list(results.get("results", {}).keys())
69
+ self.group_names: List[str] = list(results.get("groups", {}).keys())
70
+
71
+ def _get_config(self) -> Dict[str, Any]:
72
+ """Get configuration parameters."""
73
+ self.task_configs = self.results.get("configs", {})
74
+ cli_configs = self.results.get("config", {})
75
+ configs = {
76
+ "task_configs": self.task_configs,
77
+ "cli_configs": cli_configs,
78
+ }
79
+
80
+ return configs
81
+
82
+ def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
83
+ """Sanitize the results dictionary."""
84
+ _results = copy.deepcopy(self.results.get("results", dict()))
85
+
86
+ # Remove None from the metric string name
87
+ tmp_results = copy.deepcopy(_results)
88
+ for task_name in self.task_names:
89
+ task_result = tmp_results.get(task_name, dict())
90
+ for metric_name, metric_value in task_result.items():
91
+ _metric_name, removed = remove_none_pattern(metric_name)
92
+ if removed:
93
+ _results[task_name][_metric_name] = metric_value
94
+ _results[task_name].pop(metric_name)
95
+
96
+ # remove string valued keys from the results dict
97
+ wandb_summary = {}
98
+ for task in self.task_names:
99
+ task_result = _results.get(task, dict())
100
+ for metric_name, metric_value in task_result.items():
101
+ if isinstance(metric_value, str):
102
+ wandb_summary[f"{task}/{metric_name}"] = metric_value
103
+
104
+ for summary_metric, summary_value in wandb_summary.items():
105
+ _task, _summary_metric = summary_metric.split("/")
106
+ _results[_task].pop(_summary_metric)
107
+
108
+ tmp_results = copy.deepcopy(_results)
109
+ for task_name, task_results in tmp_results.items():
110
+ for metric_name, metric_value in task_results.items():
111
+ _results[f"{task_name}/{metric_name}"] = metric_value
112
+ _results[task_name].pop(metric_name)
113
+ for task in self.task_names:
114
+ _results.pop(task)
115
+
116
+ return wandb_summary, _results
117
+
118
+ def _log_results_as_table(self) -> None:
119
+ """Generate and log evaluation results as a table to W&B."""
120
+ columns = [
121
+ "Version",
122
+ "Filter",
123
+ "num_fewshot",
124
+ "Metric",
125
+ "Value",
126
+ "Stderr",
127
+ ]
128
+
129
+ def make_table(columns: List[str], key: str = "results"):
130
+ import wandb
131
+
132
+ table = wandb.Table(columns=columns)
133
+ results = copy.deepcopy(self.results)
134
+
135
+ for k, dic in results.get(key).items():
136
+ if k in self.group_names and not key == "groups":
137
+ continue
138
+ version = results.get("versions").get(k)
139
+ if version == "N/A":
140
+ version = None
141
+ n = results.get("n-shot").get(k)
142
+
143
+ for (mf), v in dic.items():
144
+ m, _, f = mf.partition(",")
145
+ if m.endswith("_stderr"):
146
+ continue
147
+ if m == "alias":
148
+ continue
149
+
150
+ if m + "_stderr" + "," + f in dic:
151
+ se = dic[m + "_stderr" + "," + f]
152
+ if se != "N/A":
153
+ se = "%.4f" % se
154
+ table.add_data(*[k, version, f, n, m, str(v), str(se)])
155
+ else:
156
+ table.add_data(*[k, version, f, n, m, str(v), ""])
157
+
158
+ return table
159
+
160
+ # log the complete eval result to W&B Table
161
+ table = make_table(["Tasks"] + columns, "results")
162
+ self.run.log({"evaluation/eval_results": table}, step=self.step)
163
+
164
+ if "groups" in self.results.keys():
165
+ table = make_table(["Groups"] + columns, "groups")
166
+ self.run.log({"evaluation/group_eval_results": table}, step=self.step)
167
+
168
+ def _log_results_as_artifact(self) -> None:
169
+ """Log results as JSON artifact to W&B."""
170
+ import wandb
171
+
172
+ dumped = json.dumps(
173
+ self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
174
+ )
175
+ artifact = wandb.Artifact("results", type="eval_results")
176
+ with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
177
+ f.write(dumped)
178
+ self.run.log_artifact(artifact)
179
+
180
+ def log_eval_result(self) -> None:
181
+ """Log evaluation results to W&B."""
182
+ # Log configs to wandb
183
+ configs = self._get_config()
184
+ self.run.config.update(configs, allow_val_change=self.step is not None)
185
+
186
+ wandb_summary, self.wandb_results = self._sanitize_results_dict()
187
+ # update wandb.run.summary with items that were removed
188
+ self.run.summary.update(wandb_summary)
189
+ # Log the evaluation metrics to wandb
190
+ self.run.log(self.wandb_results, step=self.step)
191
+ # Log the evaluation metrics as W&B Table
192
+ self._log_results_as_table()
193
+ # Log the results dict as json to W&B Artifacts
194
+ self._log_results_as_artifact()
195
+
196
+ def _generate_dataset(
197
+ self, data: List[Dict[str, Any]], config: Dict[str, Any]
198
+ ) -> pd.DataFrame:
199
+ """Generate a dataset from evaluation data.
200
+
201
+ Args:
202
+ data (List[Dict[str, Any]]): The data to generate a dataset for.
203
+ config (Dict[str, Any]): The configuration of the task.
204
+
205
+ Returns:
206
+ pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
207
+ """
208
+ ids = [x["doc_id"] for x in data]
209
+ labels = [x["target"] for x in data]
210
+ instance = [""] * len(ids)
211
+ resps = [""] * len(ids)
212
+ filtered_resps = [""] * len(ids)
213
+ model_outputs = {}
214
+
215
+ metrics_list = config["metric_list"]
216
+ metrics = {}
217
+ for metric in metrics_list:
218
+ metric = metric.get("metric")
219
+ if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
220
+ metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
221
+ if metric in ["byte_perplexity", "bits_per_byte"]:
222
+ metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
223
+ else:
224
+ metrics[f"{metric}_words"] = [x[metric][1] for x in data]
225
+ else:
226
+ metrics[metric] = [x[metric] for x in data]
227
+
228
+ if config["output_type"] == "loglikelihood":
229
+ instance = [x["arguments"][0][0] for x in data]
230
+ labels = [x["arguments"][0][1] for x in data]
231
+ resps = [
232
+ f"log probability of continuation is {x['resps'][0][0][0]} "
233
+ + "\n\n"
234
+ + "continuation will {} generated with greedy sampling".format(
235
+ "not be" if not x["resps"][0][0][1] else "be"
236
+ )
237
+ for x in data
238
+ ]
239
+ filtered_resps = [
240
+ f"log probability of continuation is {x['filtered_resps'][0][0]} "
241
+ + "\n\n"
242
+ + "continuation will {} generated with greedy sampling".format(
243
+ "not be" if not x["filtered_resps"][0][1] else "be"
244
+ )
245
+ for x in data
246
+ ]
247
+ elif config["output_type"] == "multiple_choice":
248
+ instance = [x["arguments"][0][0] for x in data]
249
+ choices = [
250
+ "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
251
+ for x in data
252
+ ]
253
+ resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
254
+ filtered_resps = [
255
+ np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
256
+ ]
257
+ elif config["output_type"] == "loglikelihood_rolling":
258
+ instance = [x["arguments"][0][0] for x in data]
259
+ resps = [x["resps"][0][0] for x in data]
260
+ filtered_resps = [x["filtered_resps"][0] for x in data]
261
+ elif config["output_type"] == "generate_until":
262
+ instance = [x["arguments"][0][0] for x in data]
263
+ resps = [x["resps"][0][0] for x in data]
264
+ filtered_resps = [x["filtered_resps"][0] for x in data]
265
+
266
+ model_outputs["raw_predictions"] = resps
267
+ model_outputs["filtered_predictions"] = filtered_resps
268
+
269
+ df_data = {
270
+ "id": ids,
271
+ "data": instance,
272
+ }
273
+ if config["output_type"] == "multiple_choice":
274
+ df_data["choices"] = choices
275
+
276
+ tmp_data = {
277
+ "input_len": [len(x) for x in instance],
278
+ "labels": labels,
279
+ "output_type": config["output_type"],
280
+ }
281
+ df_data.update(tmp_data)
282
+ df_data.update(model_outputs)
283
+ df_data.update(metrics)
284
+
285
+ return pd.DataFrame(df_data)
286
+
287
+ def _log_samples_as_artifact(
288
+ self, data: List[Dict[str, Any]], task_name: str
289
+ ) -> None:
290
+ import wandb
291
+
292
+ # log the samples as an artifact
293
+ dumped = json.dumps(
294
+ data,
295
+ indent=2,
296
+ default=_handle_non_serializable,
297
+ ensure_ascii=False,
298
+ )
299
+ artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
300
+ with artifact.new_file(
301
+ f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
302
+ ) as f:
303
+ f.write(dumped)
304
+ self.run.log_artifact(artifact)
305
+ # artifact.wait()
306
+
307
+ def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
308
+ """Log evaluation samples to W&B.
309
+
310
+ Args:
311
+ samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
312
+ """
313
+ task_names: List[str] = [
314
+ x for x in self.task_names if x not in self.group_names
315
+ ]
316
+
317
+ ungrouped_tasks = []
318
+ tasks_by_groups = {}
319
+
320
+ for task_name in task_names:
321
+ group_names = self.task_configs[task_name].get("group", None)
322
+ if group_names:
323
+ if isinstance(group_names, str):
324
+ group_names = [group_names]
325
+
326
+ for group_name in group_names:
327
+ if not tasks_by_groups.get(group_name):
328
+ tasks_by_groups[group_name] = [task_name]
329
+ else:
330
+ tasks_by_groups[group_name].append(task_name)
331
+ else:
332
+ ungrouped_tasks.append(task_name)
333
+
334
+ for task_name in ungrouped_tasks:
335
+ eval_preds = samples[task_name]
336
+
337
+ # log the samples as a W&B Table
338
+ df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
339
+ self.run.log({f"{task_name}_eval_results": df}, step=self.step)
340
+
341
+ # log the samples as a json file as W&B Artifact
342
+ self._log_samples_as_artifact(eval_preds, task_name)
343
+
344
+ for group, grouped_tasks in tasks_by_groups.items():
345
+ grouped_df = pd.DataFrame()
346
+ for task_name in grouped_tasks:
347
+ eval_preds = samples[task_name]
348
+ df = self._generate_dataset(
349
+ eval_preds, self.task_configs.get(task_name)
350
+ )
351
+ df["group"] = group
352
+ df["task"] = task_name
353
+ grouped_df = pd.concat([grouped_df, df], ignore_index=True)
354
+
355
+ # log the samples as a json file as W&B Artifact
356
+ self._log_samples_as_artifact(eval_preds, task_name)
357
+
358
+ self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import (
2
+ diffllm,
3
+ huggingface,
4
+ )
5
+
6
+
7
+ # TODO: implement __all__
8
+
9
+
10
+ try:
11
+ # enable hf hub transfer if available
12
+ import hf_transfer # type: ignore # noqa
13
+ import huggingface_hub.constants # type: ignore
14
+
15
+ huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
16
+ except ImportError:
17
+ pass
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/diffllm.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gc
3
+ import random
4
+ import json
5
+ import os
6
+ import time
7
+ from datetime import timedelta
8
+ from typing import List, Optional, Tuple, Type, TypeVar, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import transformers
13
+ from accelerate import (
14
+ Accelerator,
15
+ InitProcessGroupKwargs,
16
+ find_executable_batch_size,
17
+ )
18
+ from datasets import Dataset
19
+ from packaging import version
20
+ from tqdm import tqdm
21
+
22
+ from lm_eval import utils
23
+ from lm_eval.api.instance import Instance
24
+ from lm_eval.api.model import LM
25
+ from lm_eval.api.registry import register_model
26
+ from lm_eval.models.utils import Collator, get_dtype
27
+
28
+ eval_logger = logging.getLogger(__name__)
29
+ T = TypeVar("T", bound="LM")
30
+
31
+
32
+ def empty_cache_by_memory(threshold_gb=70):
33
+ """
34
+ Empty CUDA cache if allocated memory exceeds threshold
35
+ Args:
36
+ threshold_gb: Memory threshold in GB
37
+ """
38
+ if torch.cuda.is_available():
39
+ # Get current memory allocated
40
+ allocated = torch.cuda.memory_allocated() / 1024**3 # Convert to GB
41
+
42
+ if allocated > threshold_gb:
43
+ # Clear cache
44
+ gc.collect()
45
+ torch.cuda.empty_cache()
46
+ print(f"Cache cleared. Memory freed: {allocated:.2f} GB")
47
+
48
+ @register_model("diffllm")
49
+ class DiffLLM(LM):
50
+ def __init__(
51
+ self,
52
+ pretrained: Union[str, transformers.PreTrainedModel],
53
+ batch_size: Optional[Union[int, str]] = 1,
54
+ device: Optional[str] = "cuda",
55
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
56
+ max_prompt_len: Optional[int] = 1024,
57
+ max_new_tokens: Optional[int] = 128,
58
+ nll_type: Optional[str] = "mc",
59
+ log_type: Optional[str] = "ftb",
60
+ classifier_free_guidance: Optional[float] = 1.0,
61
+ pad_to_max_len: Optional[bool] = False,
62
+ sampling_eps: Optional[float] = 1e-3,
63
+ diffusion_steps: Optional[int] = 32,
64
+ trust_remote_code: Optional[bool] = True,
65
+ parallelize: Optional[bool] = False,
66
+ autogptq: Optional[Union[bool, str]] = False,
67
+ **kwargs,
68
+ ) -> None:
69
+ super().__init__()
70
+
71
+ # prepare for parallelism
72
+ assert isinstance(device, str)
73
+ assert isinstance(pretrained, str)
74
+ assert isinstance(batch_size, (int, str))
75
+
76
+ gpus = torch.cuda.device_count()
77
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
78
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
79
+
80
+ self.accelerator = accelerator
81
+
82
+ if "npu" in accelerator.device.type:
83
+ gpus = torch.npu.device_count()
84
+
85
+ # using one process with no model parallelism
86
+ if not (parallelize or accelerator.num_processes > 1):
87
+ # use user-passed device
88
+ device_list = set(
89
+ ["cuda", "cpu"]
90
+ + [f"cuda:{i}" for i in range(gpus)]
91
+ + ["mps", "mps:0"]
92
+ + [f"npu:{i}" for i in range(gpus)]
93
+ )
94
+ if device and device in device_list:
95
+ self._device = torch.device(device)
96
+ eval_logger.info(f"Using device '{device}'")
97
+ if device in ("mps", "mps:0") and version.parse(
98
+ torch.__version__
99
+ ) < version.parse("2.1"):
100
+ raise RuntimeError(
101
+ f"mps requires torch >= 2.1. You have {torch.__version__}"
102
+ )
103
+ else:
104
+ eval_logger.info("Device not specified")
105
+ eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
106
+ self._device = (
107
+ torch.device("cuda")
108
+ if torch.cuda.is_available()
109
+ else torch.device("cpu")
110
+ )
111
+ else:
112
+ if device != "cuda":
113
+ eval_logger.info(
114
+ f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
115
+ )
116
+ self._device = self.accelerator.device
117
+
118
+ self.batch_size_per_gpu = batch_size
119
+ if isinstance(batch_size, str):
120
+ self.batch_size_per_gpu = int(batch_size)
121
+ self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
122
+
123
+ if isinstance(pretrained, str):
124
+ if gpus >= 1 or str(self.device) == "mps":
125
+ if not (parallelize or autogptq or (hasattr(self, "accelerator") and self.accelerator.num_processes > 1)):
126
+ try:
127
+ self.model.to(self.device)
128
+ except ValueError:
129
+ eval_logger.debug(
130
+ "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
131
+ )
132
+ if gpus > 1:
133
+ if self.accelerator.num_processes > 1:
134
+ self._device = torch.device(f"{accelerator.device}")
135
+ self._rank = self.accelerator.local_process_index
136
+ self._world_size = self.accelerator.num_processes
137
+ else:
138
+ self._rank = 0
139
+ self._world_size = 1
140
+ else:
141
+ self._rank = 0
142
+ self._world_size = 1
143
+ else:
144
+ eval_logger.warning(
145
+ "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
146
+ )
147
+ self._rank = 0
148
+ self._world_size = 1
149
+
150
+ self.max_prompt_len = max_prompt_len
151
+ self.max_new_tokens = max_new_tokens
152
+ self.diffusion_steps = diffusion_steps
153
+ self.temperature = kwargs.get("temperature", 0.7)
154
+ self.top_p = kwargs.get("top_p", 0.95)
155
+ self.alg = kwargs.get("alg", "entropy")
156
+ self.alg_temp = kwargs.get("alg_temp", 0.0)
157
+ self.top_k = kwargs.get("top_k", None)
158
+
159
+ self.nll_type = nll_type
160
+ self.log_type = log_type
161
+ self.classifier_free_guidance = classifier_free_guidance
162
+ self.pad_to_max_len = pad_to_max_len
163
+ self.sampling_eps = sampling_eps
164
+
165
+ self.mask_id = 151666
166
+ self.eos_id = 151643
167
+
168
+ raw_use_hts = kwargs.get("use_hts", False)
169
+ if isinstance(raw_use_hts, str):
170
+ self.use_hts = raw_use_hts.lower() == "true"
171
+ else:
172
+ self.use_hts = bool(raw_use_hts)
173
+
174
+ self.realtime_output = kwargs.get("realtime_output", "eval_results.jsonl")
175
+
176
+ if self.use_hts:
177
+ from .hts_sampler import HTSSampler
178
+ self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
179
+ eval_logger.info(f"Rank {self.rank}: HTS Sampler initialized for Dream.")
180
+
181
+ @property
182
+ def batch_size(self):
183
+ return self.batch_size_per_gpu
184
+
185
+ @property
186
+ def device(self):
187
+ return self._device
188
+
189
+ @property
190
+ def rank(self):
191
+ return self._rank
192
+
193
+ @property
194
+ def world_size(self):
195
+ return self._world_size
196
+
197
+ def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
198
+ self.model = (
199
+ transformers.AutoModel.from_pretrained(
200
+ pretrained,
201
+ torch_dtype=get_dtype(dtype),
202
+ trust_remote_code=trust_remote_code,
203
+ )
204
+ .eval()
205
+ ).to(self.device)
206
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
207
+ pretrained, trust_remote_code=trust_remote_code
208
+ )
209
+
210
+ def tok_decode(self, tokens, skip_special_tokens=True):
211
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
212
+
213
+ def tok_encode(self, text, add_special_tokens=True):
214
+ return self.tokenizer(
215
+ text, return_tensors="pt", add_special_tokens=add_special_tokens
216
+ ).input_ids
217
+
218
+ @classmethod
219
+ def create_from_arg_string(
220
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
221
+ ) -> T:
222
+ additional_config = {} if additional_config is None else additional_config
223
+ args = utils.simple_parse_args_string(arg_string)
224
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
225
+ return cls(**args, **args2)
226
+
227
+ def apply_chat_template(
228
+ self, chat_history, add_generation_prompt: bool = True
229
+ ) -> str:
230
+ chat_templated = self.tokenizer.apply_chat_template(
231
+ chat_history,
232
+ tokenize=False,
233
+ add_generation_prompt=add_generation_prompt,
234
+ continue_final_message=not add_generation_prompt,
235
+ )
236
+ return chat_templated
237
+
238
+ @property
239
+ def tokenizer_name(self) -> str:
240
+ return self.tokenizer.name_or_path.replace("/", "__")
241
+
242
+ def _generate_batch(self, prompts: List[str], gen_kwargs: dict = None) -> Tuple[List[str], List[dict]]:
243
+ raw_val = gen_kwargs.get("use_hts", self.use_hts)
244
+ use_hts_now = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val
245
+
246
+ all_stats = []
247
+ if not use_hts_now:
248
+ prompt_ids = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").input_ids
249
+ prompt_ids = prompt_ids[:, -self.max_prompt_len:]
250
+ attn_mask = prompt_ids.ne(self.tokenizer.pad_token_id).to(self.device)
251
+ prompt_ids = prompt_ids.to(device=self.device)
252
+
253
+ generation_ids = self.model.diffusion_generate(
254
+ prompt_ids,
255
+ attention_mask=attn_mask,
256
+ max_new_tokens=self.max_new_tokens,
257
+ output_history=False,
258
+ return_dict_in_generate=True,
259
+ steps=self.diffusion_steps,
260
+ temperature=self.temperature,
261
+ top_p=self.top_p,
262
+ top_k=self.top_k,
263
+ alg=self.alg,
264
+ alg_temp=self.alg_temp,
265
+ )
266
+ responses = [
267
+ self.tokenizer.decode(g[len(p) :].tolist()).split(self.tokenizer.eos_token)[0]
268
+ for p, g in zip(prompt_ids, generation_ids.sequences)
269
+ ]
270
+ all_stats = [{} for _ in responses]
271
+ return responses, all_stats
272
+ else:
273
+ if not hasattr(self, "hts_sampler"):
274
+ from .hts_sampler import HTSSampler
275
+ self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device)
276
+
277
+ results = []
278
+ for prompt in prompts:
279
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
280
+
281
+ final_codes, stats = self.hts_sampler.generate_hts(
282
+ prompt_text=prompt,
283
+ input_ids=input_ids,
284
+ initial_N=int(gen_kwargs.get("initial_N", 4)),
285
+ final_K=int(gen_kwargs.get("final_K", 1)),
286
+ hts_survivor_k=int(gen_kwargs.get("hts_survivor_k", 4)),
287
+ reward_mode=gen_kwargs.get("reward_mode", "svf"),
288
+ task_type=gen_kwargs.get("task_type", "code"),
289
+ steps=self.diffusion_steps,
290
+ gen_length=self.max_new_tokens,
291
+ temperature=float(gen_kwargs.get("temperature", self.temperature)),
292
+ top_p=float(gen_kwargs.get("top_p", self.top_p)),
293
+ top_k=gen_kwargs.get("top_k", self.top_k),
294
+ until=gen_kwargs.get("until", []),
295
+ hts_mode=True,
296
+ mask_id=self.mask_id,
297
+ eos_id=self.eos_id
298
+ )
299
+
300
+ results.append(final_codes[0])
301
+ all_stats.append(stats)
302
+ return results, all_stats
303
+
304
+ def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
305
+ res = []
306
+
307
+ gen_kwargs_first = requests[0].args[1]
308
+ actual_output_path = gen_kwargs_first.get("realtime_output", self.realtime_output)
309
+
310
+ raw_val = gen_kwargs_first.get("use_hts", self.use_hts)
311
+ self.use_hts = str(raw_val).lower() == "true" if not isinstance(raw_val, bool) else raw_val
312
+
313
+ rank_tmp_file = actual_output_path.replace(".jsonl", f"_rank{self.rank}.tmp")
314
+
315
+ output_dir = os.path.dirname(rank_tmp_file)
316
+ if output_dir and not os.path.exists(output_dir):
317
+ os.makedirs(output_dir, exist_ok=True)
318
+
319
+ pbar = tqdm(
320
+ total=len(requests),
321
+ disable=(disable_tqdm or (self.rank != 0)),
322
+ desc="Running generate_until",
323
+ )
324
+
325
+ for batch_idx in range(0, len(requests), self.batch_size_per_gpu):
326
+ batch_requests = requests[batch_idx : batch_idx + self.batch_size_per_gpu]
327
+ contexts, task_gen_args = zip(*[req.arguments for req in batch_requests])
328
+
329
+ responses, stats_list = self._generate_batch(contexts, gen_kwargs=task_gen_args[0])
330
+
331
+ for i, r in enumerate(responses):
332
+ r = r.replace("```python", "").replace("```", "")
333
+
334
+ for s in task_gen_args[0].get('until', []):
335
+ r = r.split(s)[0]
336
+
337
+ target_val = getattr(batch_requests[i], "target", None)
338
+ if target_val is None or target_val == "N/A":
339
+ target_val = batch_requests[i].doc.get("answer", batch_requests[i].doc.get("solution", "N/A"))
340
+
341
+ save_data = {
342
+ "doc": batch_requests[i].doc,
343
+ "target": target_val,
344
+ "prompt": contexts[i],
345
+ "response": r,
346
+ }
347
+
348
+ if self.use_hts:
349
+ save_data.update(stats_list[i])
350
+
351
+ with open(rank_tmp_file, "a", encoding="utf-8") as f:
352
+ f.write(json.dumps(save_data, ensure_ascii=False) + "\n")
353
+ f.flush()
354
+
355
+ responses[i] = r
356
+
357
+ if self.rank == 0 and batch_idx == 0:
358
+ print(f"Sample Response:\n{responses[0]}\n")
359
+
360
+ res.extend(responses)
361
+ pbar.update(len(batch_requests))
362
+
363
+ pbar.close()
364
+
365
+ self.accelerator.wait_for_everyone()
366
+
367
+ if self.rank == 0:
368
+ eval_logger.info(f"Merging rank files into {actual_output_path}...")
369
+ with open(actual_output_path, "w", encoding="utf-8") as final_f:
370
+ for r in range(self.world_size):
371
+ temp_f = actual_output_path.replace(".jsonl", f"_rank{r}.tmp")
372
+ if os.path.exists(temp_f):
373
+ with open(temp_f, "r", encoding="utf-8") as tf:
374
+ for line in tf:
375
+ final_f.write(line)
376
+ os.remove(temp_f)
377
+ eval_logger.info("Merge completed.")
378
+
379
+ return res
380
+
381
+ def _forward_process(self, batch):
382
+ b, l = batch.shape
383
+ u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
384
+ indices = torch.arange(b, device=batch.device).float()
385
+ t = (u0 + indices / b) % 1
386
+ p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
387
+ p_mask = p_mask[:, None].repeat(1, l)
388
+ mask_indices = torch.rand((b, l), device=batch.device) < p_mask
389
+ mask_indices[:, 0] = False
390
+ mask_indices[:, -1] = False
391
+ noisy_batch = torch.where(mask_indices, self.mask_id, batch)
392
+ return noisy_batch, p_mask
393
+
394
+ @torch.no_grad()
395
+ def get_logits(self, batch, prompt_index):
396
+ if self.classifier_free_guidance > 1.:
397
+ assert len(prompt_index) == batch.shape[1]
398
+ prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
399
+ un_batch = batch.clone()
400
+ un_batch[prompt_index] = self.mask_id
401
+ batch = torch.cat([batch, un_batch])
402
+
403
+ if self.pad_to_max_len:
404
+ raise NotImplementedError
405
+ else:
406
+ input = batch
407
+
408
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
409
+ logits = self.model(input, 'full').logits
410
+
411
+ if self.classifier_free_guidance > 1.:
412
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
413
+ logits = un_logits + self.classifier_free_guidance * (logits - un_logits)
414
+ return logits[:, :batch.shape[1]]
415
+
416
+ @torch.no_grad()
417
+ def _eval_target_nll_mc(self, prefix, target):
418
+ if prefix is None:
419
+ seq = target[None, :]
420
+ else:
421
+ seq = torch.concatenate([prefix, target])[None, :]
422
+ seq = seq.repeat((self.batch_size, 1)).to(self.device)
423
+
424
+ if self.log_type == 'ftb':
425
+ prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
426
+ else:
427
+ prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)
428
+
429
+ loss_acc = []
430
+ mc_num = self.diffusion_steps
431
+ for _ in range(max(mc_num // self.batch_size, 1)):
432
+ perturbed_seq = seq.clone()
433
+ perturbed_seq_, p_mask = self._forward_process(seq)
434
+ if self.log_type == 'ftb':
435
+ perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
436
+ elif self.log_type == 'btf':
437
+ perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
438
+ elif self.log_type == 'union':
439
+ perturbed_seq = perturbed_seq_
440
+ else:
441
+ raise NotImplementedError(self.log_type)
442
+
443
+ mask_indices = perturbed_seq == self.mask_id
444
+
445
+ logits = self.get_logits(perturbed_seq, prompt_index)
446
+
447
+ loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
448
+ loss = loss.sum() / self.batch_size
449
+ loss_acc.append(loss.item())
450
+ del logits, loss, perturbed_seq, perturbed_seq_, p_mask, mask_indices
451
+ empty_cache_by_memory(threshold_gb=70)
452
+
453
+ return sum(loss_acc) / len(loss_acc)
454
+
455
+ @torch.no_grad()
456
+ def _eval_target_nll_ar(self, prefix, target):
457
+ prefix, target = prefix.unsqueeze(0), target.unsqueeze(0) # 1*l1, 1*l2
458
+ assert self.log_type in ['ftb', 'btf']
459
+ assert self.nll_type in ['ar_ftb', 'ar_btf']
460
+
461
+ if self.log_type == 'ftb':
462
+ prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
463
+ else:
464
+ prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]
465
+
466
+ if self.log_type == 'ftb':
467
+ perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous() # l2*l2
468
+ else:
469
+ perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous() # l1*l1
470
+
471
+ mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
472
+ if self.nll_type == 'ar_ftb':
473
+ mask_index = torch.triu(mask_index)
474
+ else:
475
+ mask_index = torch.tril(mask_index)
476
+ perturbed_[mask_index] = self.mask_id
477
+ if self.log_type == 'ftb':
478
+ perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
479
+ else:
480
+ perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)
481
+
482
+ logits_ = []
483
+ num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
484
+ for i in range(num):
485
+ end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
486
+ perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
487
+ perturbed_seq_ = perturbed_seq_.to(self.device)
488
+ if len(perturbed_seq_.shape) == 1:
489
+ perturbed_seq_ = perturbed_seq_.unsqueeze(0)
490
+ logits = self.get_logits(perturbed_seq_, prompt_index)
491
+ logits_.append(logits.cpu())
492
+ logits = torch.cat(logits_, dim=0)
493
+
494
+ temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
495
+ if self.nll_type == 'ar_ftb':
496
+ temp_index = torch.triu(temp_index, diagonal=1)
497
+ else:
498
+ temp_index = torch.tril(temp_index, diagonal=-1)
499
+ mask_index[temp_index] = False
500
+ if self.log_type == 'ftb':
501
+ logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
502
+ else:
503
+ logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)
504
+
505
+ if self.log_type == 'ftb':
506
+ loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
507
+ else:
508
+ loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
509
+ return loss
510
+
511
+ def _encode_pair(self, context, continuation):
512
+ n_spaces = len(context) - len(context.rstrip())
513
+ if n_spaces > 0:
514
+ continuation = context[-n_spaces:] + continuation
515
+ context = context[:-n_spaces]
516
+
517
+ whole_enc = self.tokenizer.encode(context + continuation) + [
518
+ self.tokenizer.eos_token_id
519
+ ]
520
+ context_enc = self.tokenizer.encode(context)
521
+
522
+ context_enc_len = len(context_enc)
523
+ continuation_enc = whole_enc[context_enc_len:]
524
+
525
+ return context_enc, continuation_enc
526
+
527
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
528
+ def _tokenize(e):
529
+ prefix, target = self._encode_pair(e["prefix"], e["target"])
530
+ return {
531
+ "prefix_text": e["prefix"],
532
+ "target_text": e["target"],
533
+ "prefix": prefix,
534
+ "target": target,
535
+ }
536
+
537
+ ds = []
538
+ ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
539
+ ds = Dataset.from_list(ds)
540
+ ds = ds.map(_tokenize)
541
+ ds = ds.with_format("torch")
542
+
543
+ out = []
544
+ with torch.no_grad():
545
+ for elem in tqdm(ds, desc="Computing likelihood..."):
546
+ prefix = elem["prefix"]
547
+ target = elem["target"]
548
+
549
+ if self.nll_type == 'mc':
550
+ ll = -self._eval_target_nll_mc(prefix, target)
551
+ if self.log_type == 'union':
552
+ ll = ll / (len(target) + len(prefix))
553
+ elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
554
+ ll = -self._eval_target_nll_ar(prefix, target)
555
+ else:
556
+ raise NotImplementedError(self.nll_type)
557
+
558
+ is_target_greedy_dec = False
559
+ out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
560
+ return out
561
+
562
+ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
563
+ raise NotImplementedError
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/dummy.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from tqdm import tqdm
4
+
5
+ from lm_eval.api.model import LM
6
+ from lm_eval.api.registry import register_model
7
+
8
+
9
+ @register_model("dummy")
10
+ class DummyLM(LM):
11
+ def __init__(self) -> None:
12
+ super().__init__()
13
+
14
+ @classmethod
15
+ def create_from_arg_string(cls, arg_string, additional_config=None):
16
+ return cls()
17
+
18
+ def loglikelihood(self, requests, disable_tqdm: bool = False):
19
+ res = []
20
+
21
+ for _ in tqdm(requests, disable=disable_tqdm):
22
+ res.append((-random.random(), False))
23
+
24
+ return res
25
+
26
+ def generate_until(self, requests, disable_tqdm: bool = False):
27
+ res = []
28
+
29
+ for request in tqdm(requests, disable=disable_tqdm):
30
+ res.append("lol")
31
+ assert request.arguments[0].strip() != ""
32
+
33
+ return res
34
+
35
+ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
36
+ res = []
37
+
38
+ for _ in tqdm(requests, disable=disable_tqdm):
39
+ res.append(-random.random())
40
+
41
+ return res
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/hts_sampler.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from .verifier import CodeVerifier
5
+ import logging
6
+ import re
7
+ import math
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class HTSSampler:
12
+ def __init__(self, model, tokenizer, device="cuda"):
13
+ self.model = model
14
+ self.tokenizer = tokenizer
15
+ self.device = device
16
+ self.verifier = CodeVerifier(model, tokenizer, device)
17
+
18
+ def _get_num_transfer_tokens(self, block_length, steps):
19
+ if steps == 0: return torch.tensor([], dtype=torch.int64)
20
+ base = block_length // steps
21
+ remainder = block_length % steps
22
+ num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64)
23
+ num_transfer_tokens[:remainder] += 1
24
+ return num_transfer_tokens
25
+
26
+ def _sample_with_temperature(self, logits, temperature, top_k, top_p):
27
+ logits = logits.to(torch.float32)
28
+ orig_probs = torch.softmax(logits, dim=-1)
29
+ x0_p, _ = torch.max(orig_probs, dim=-1)
30
+
31
+ if temperature > 0.0:
32
+ noise = torch.rand_like(logits, dtype=torch.float32)
33
+ gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10)
34
+ logits = logits / temperature + gumbel_noise
35
+
36
+ if top_k is not None and top_k > 0:
37
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
38
+ logits[indices_to_remove] = -float('Inf')
39
+
40
+ x0 = torch.argmax(logits, dim=-1)
41
+ return x0, x0_p
42
+
43
+ def _safe_scalar(self, val):
44
+ if isinstance(val, torch.Tensor):
45
+ if val.numel() > 1: return val.mean().item()
46
+ return val.item()
47
+ return float(val)
48
+
49
+ def _analyze_structure(self, text, task_type="code"):
50
+ score = 0.0
51
+ stripped = text.strip()
52
+ if task_type == "code":
53
+ if len(stripped) < 5: return -0.1
54
+ keywords = ["return", "print", "yield", "lambda", "class ", "def "]
55
+ if any(k in stripped for k in keywords): score += 0.05
56
+ if ":" in stripped: score += 0.02
57
+ if " " in text: score += 0.03
58
+ elif task_type == "math":
59
+ if "\\boxed{" in stripped: score += 0.1
60
+ if "The answer is" in stripped: score += 0.05
61
+ return score
62
+
63
+ def _chunked_forward(self, x, chunk_size=32, slice_indices=None):
64
+ total_batch = x.shape[0]
65
+ logits_list = []
66
+ for i in range(0, total_batch, chunk_size):
67
+ end_idx = min(i + chunk_size, total_batch)
68
+ sub_x = x[i:end_idx]
69
+ with torch.no_grad():
70
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
71
+ outputs = self.model(sub_x, 'full')
72
+ sub_logits = outputs.logits
73
+ sub_logits = torch.cat([sub_logits[:, :1, :], sub_logits[:, :-1, :]], dim=1)
74
+ if slice_indices is not None:
75
+ s_start, s_end = slice_indices
76
+ sub_logits = sub_logits[:, s_start:s_end, :]
77
+ logits_list.append(sub_logits.detach().clone())
78
+ return torch.cat(logits_list, dim=0)
79
+
80
+ def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id,
81
+ prompt_length, resample_window=6, task_type="code"):
82
+ num_survivors = len(survivor_indices)
83
+ if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone()
84
+
85
+ base_repeat = target_width // num_survivors
86
+ remainder = target_width % num_survivors
87
+ new_x_list, new_conf_list = [], []
88
+
89
+ for i in range(num_survivors):
90
+ count = base_repeat + (1 if i < remainder else 0)
91
+ if count == 0: continue
92
+ survivor_x = x[survivor_indices[i]]
93
+ survivor_conf = conf_scores[survivor_indices[i]]
94
+
95
+ new_x_list.append(survivor_x.unsqueeze(0))
96
+ new_conf_list.append(survivor_conf.unsqueeze(0))
97
+
98
+ if count > 1:
99
+ gen_part = survivor_x[prompt_length:]
100
+ gen_conf = survivor_conf[prompt_length:]
101
+ non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0]
102
+ for _ in range(count - 1):
103
+ perturbed_x = survivor_x.clone()
104
+ perturbed_conf = survivor_conf.clone()
105
+ if len(non_mask_indices) > 0:
106
+ pool_size = min(resample_window * 2, len(non_mask_indices))
107
+ current_token_confs = gen_conf[non_mask_indices]
108
+ _, candidate_pool = torch.topk(current_token_confs, k=pool_size, largest=False)
109
+
110
+ num_to_perturb = min(resample_window, pool_size)
111
+ rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb]
112
+ selected_sub_indices = candidate_pool[rand_indices]
113
+
114
+ target_idx_in_x = prompt_length + non_mask_indices[selected_sub_indices]
115
+ perturbed_x[target_idx_in_x] = mask_id
116
+ perturbed_conf[target_idx_in_x] = 0.0
117
+ new_x_list.append(perturbed_x.unsqueeze(0))
118
+ new_conf_list.append(perturbed_conf.unsqueeze(0))
119
+
120
+ return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0)
121
+
122
+ @torch.no_grad()
123
+ def generate_hts(self, prompt_text, input_ids, problem_data=None,
124
+ initial_N=1, final_K=1, survivor_K=None,
125
+ prune_step_pct=0.0, reward_mode="confidence",
126
+ temperature=0.7, block_length=32, steps=64, gen_length=1024,
127
+ top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9,
128
+ eos_id=151643, mask_id=151666,
129
+ hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.2,
130
+ hts_survivor_k=4, task_type="code", until=None, pruning_interval=20):
131
+
132
+ input_ids = input_ids.to(self.device)
133
+ prompt_length = input_ids.shape[1]
134
+ total_length = prompt_length + gen_length
135
+
136
+ x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device)
137
+ x[:, :prompt_length] = input_ids.repeat(initial_N, 1)
138
+ conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device)
139
+ conf_scores[:, :prompt_length] = 1.0
140
+
141
+ schedule = self._get_num_transfer_tokens(gen_length, steps)
142
+ current_bsz = initial_N
143
+ schedule_map = {}
144
+ ts_start, tr_end = 0, 0
145
+
146
+ if hts_mode:
147
+ ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct)
148
+ else:
149
+ final_K_list = [final_K] if not isinstance(final_K, list) else final_K
150
+ prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct
151
+ for pct, width in zip(prune_pct_list, final_K_list):
152
+ if pct > 0: schedule_map[int(steps * pct)] = width
153
+
154
+ stats = {
155
+ "initial_n": initial_N,
156
+ "final_k": final_K if not isinstance(final_K, list) else final_K[-1],
157
+ "nfe": 0,
158
+ "svf_calls": 0,
159
+ "pruning_history": [],
160
+ "entropy_history": [],
161
+ "final_scores": []
162
+ }
163
+
164
+ next_allowed_pruning_step = ts_start
165
+
166
+ for step in range(steps):
167
+ perform_pruning = False
168
+ num_parents_to_select = hts_survivor_k
169
+
170
+ if hts_mode and ts_start <= step < tr_end and step >= next_allowed_pruning_step:
171
+ target_width = max(stats["final_k"], math.ceil(initial_N * (decay_factor ** -(step - ts_start))))
172
+ if current_bsz > target_width:
173
+ perform_pruning = True
174
+ elif not hts_mode and step in schedule_map:
175
+ target_width = schedule_map[step]
176
+ num_parents_to_select = target_width
177
+ if current_bsz > target_width:
178
+ perform_pruning = True
179
+
180
+ if perform_pruning:
181
+ stats["svf_calls"] += current_bsz
182
+ full_logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length))
183
+ rough_ids = torch.argmax(full_logits, dim=-1)
184
+ rough_codes = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True)
185
+
186
+ candidates = []
187
+ for i in range(current_bsz):
188
+ s = self._safe_scalar(self.verifier.get_reward(prompt_text, rough_codes[i], mode=reward_mode, current_logits=full_logits[i], task_type=task_type))
189
+ s += self._analyze_structure(rough_codes[i], task_type=task_type)
190
+ clean_text = rough_codes[i].strip().replace(" ", "").replace("\n", "")
191
+ content_key = hash(clean_text[:150] + clean_text[-150:]) if clean_text else i
192
+ candidates.append({'score': s, 'idx': i, 'key': content_key})
193
+
194
+ stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]})
195
+ candidates.sort(key=lambda c: c['score'], reverse=True)
196
+
197
+ selected_indices, seen_keys = [], set()
198
+ for cand in candidates:
199
+ if len(selected_indices) >= num_parents_to_select: break
200
+ if cand['key'] not in seen_keys:
201
+ selected_indices.append(cand['idx']); seen_keys.add(cand['key'])
202
+ for cand in candidates:
203
+ if len(selected_indices) >= num_parents_to_select: break
204
+ if cand['idx'] not in selected_indices: selected_indices.append(cand['idx'])
205
+
206
+ top_indices = torch.tensor(selected_indices, device=self.device)
207
+ x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type)
208
+
209
+ current_bsz = target_width
210
+ next_allowed_pruning_step = step + pruning_interval
211
+
212
+ active_mask = (x[:current_bsz, prompt_length:] == mask_id)
213
+
214
+ stats["nfe"] += current_bsz
215
+ logits = self._chunked_forward(x[:current_bsz, :], slice_indices=(prompt_length, total_length))
216
+
217
+ with torch.no_grad():
218
+ probs = torch.softmax(logits.float(), dim=-1)
219
+ entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean().item()
220
+ stats["entropy_history"].append(entropy)
221
+
222
+ x0, x0_p = self._sample_with_temperature(logits, temperature, top_k, top_p)
223
+ num_transfer = schedule[step].item()
224
+
225
+ confidence = torch.where(active_mask, x0_p, -torch.inf)
226
+ transfer_idx = torch.zeros_like(x0, dtype=torch.bool)
227
+
228
+ for b in range(current_bsz):
229
+ k = min(num_transfer, active_mask[b].sum().item())
230
+ if k <= 0: continue
231
+ high_conf_mask = (confidence[b] > threshold)
232
+ if high_conf_mask.sum() >= k:
233
+ transfer_idx[b] = high_conf_mask
234
+ else:
235
+ _, topk_ids = torch.topk(confidence[b], k=k)
236
+ transfer_idx[b, topk_ids] = True
237
+
238
+ if transfer_idx.any():
239
+ x[:current_bsz, prompt_length:][transfer_idx] = x0[transfer_idx]
240
+ conf_scores[:current_bsz, prompt_length:][transfer_idx] = x0_p[transfer_idx]
241
+
242
+ final_codes = self.tokenizer.batch_decode(x[:current_bsz, prompt_length:], skip_special_tokens=True)
243
+ final_candidates = []
244
+ for i, code in enumerate(final_codes):
245
+ txt = code.split(self.tokenizer.eos_token)[0]
246
+ if until:
247
+ for term in until:
248
+ if term in txt: txt = txt.split(term)[0]
249
+ s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type))
250
+ final_candidates.append({'resp': txt, 'score': s})
251
+
252
+ final_candidates.sort(key=lambda c: c['score'], reverse=True)
253
+ stats["final_scores"] = [c['score'] for c in final_candidates]
254
+ stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)]
255
+
256
+ return [c['resp'] for c in final_candidates], stats
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/huggingface.py ADDED
@@ -0,0 +1,1459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from datetime import timedelta
5
+ from pathlib import Path
6
+ from typing import Dict, List, Literal, Optional, Tuple, Union
7
+
8
+ import jinja2
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import transformers
12
+ from accelerate import (
13
+ Accelerator,
14
+ InitProcessGroupKwargs,
15
+ find_executable_batch_size,
16
+ )
17
+ from accelerate.utils import get_max_memory
18
+ from huggingface_hub import HfApi
19
+ from packaging import version
20
+ from peft import PeftModel
21
+ from peft import __version__ as PEFT_VERSION
22
+ from tqdm import tqdm
23
+ from transformers.models.auto.modeling_auto import (
24
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
25
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
26
+ )
27
+
28
+ from lm_eval import utils
29
+ from lm_eval.api.instance import Instance
30
+ from lm_eval.api.model import TemplateLM
31
+ from lm_eval.api.registry import register_model
32
+ from lm_eval.models.utils import (
33
+ Collator,
34
+ clear_torch_cache,
35
+ configure_pad_token,
36
+ get_dtype,
37
+ handle_stop_sequences,
38
+ pad_and_concat,
39
+ stop_sequences_criteria,
40
+ )
41
+
42
+
43
+ eval_logger = logging.getLogger(__name__)
44
+
45
+
46
+ @register_model("hf-auto", "hf", "huggingface")
47
+ class HFLM(TemplateLM):
48
+ """
49
+ An abstracted Huggingface model class. Enables usage with both models of
50
+ `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
51
+
52
+ Supports data-parallel multi-GPU with HF Accelerate.
53
+ """
54
+
55
+ AUTO_MODEL_CLASS = None
56
+ _DEFAULT_MAX_LENGTH = 2048
57
+
58
+ def __init__(
59
+ self,
60
+ pretrained: Union[str, transformers.PreTrainedModel],
61
+ backend: Literal["default", "causal", "seq2seq"] = "default",
62
+ # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
63
+ revision: Optional[str] = "main",
64
+ subfolder: Optional[str] = None,
65
+ tokenizer: Optional[
66
+ Union[
67
+ str,
68
+ transformers.PreTrainedTokenizer,
69
+ transformers.PreTrainedTokenizerFast,
70
+ ]
71
+ ] = None,
72
+ truncation: Optional[bool] = False,
73
+ logits_cache: bool = True,
74
+ max_length: Optional[int] = None,
75
+ device: Optional[str] = "cuda",
76
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
77
+ batch_size: Optional[Union[int, str]] = 1,
78
+ max_batch_size: Optional[int] = 64,
79
+ trust_remote_code: Optional[bool] = False,
80
+ use_fast_tokenizer: Optional[bool] = True,
81
+ add_bos_token: Optional[bool] = False,
82
+ prefix_token_id: Optional[int] = None,
83
+ # arguments used for splitting a model across GPUs naively.
84
+ # only used if `parallelize=True`.
85
+ parallelize: Optional[bool] = False,
86
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
87
+ max_cpu_memory: Optional[Union[int, str]] = None,
88
+ offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
89
+ # PEFT, delta weights and quantization options
90
+ peft: Optional[str] = None,
91
+ delta: Optional[str] = None,
92
+ autogptq: Optional[Union[bool, str]] = False,
93
+ gptqmodel: Optional[bool] = False,
94
+ gguf_file: Optional[str] = None,
95
+ **kwargs,
96
+ ) -> None:
97
+ super().__init__()
98
+ # optionally: take in an already-initialized transformers.PreTrainedModel
99
+ if not isinstance(pretrained, str):
100
+ eval_logger.warning(
101
+ "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
102
+ )
103
+ assert not parallelize, (
104
+ "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
105
+ )
106
+ self._model = pretrained
107
+ self._device = self._model.device
108
+ self._config = self._model.config
109
+ gpus = 0
110
+
111
+ else:
112
+ assert isinstance(device, str)
113
+ assert isinstance(pretrained, str)
114
+ assert isinstance(batch_size, (int, str))
115
+
116
+ gpus = torch.cuda.device_count()
117
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
118
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
119
+ if accelerator.num_processes > 1:
120
+ self.accelerator = accelerator
121
+
122
+ if "npu" in accelerator.device.type:
123
+ gpus = torch.npu.device_count()
124
+
125
+ # using one process with no model parallelism
126
+ if not (parallelize or accelerator.num_processes > 1):
127
+ # use user-passed device
128
+ device_list = set(
129
+ ["cuda", "cpu"]
130
+ + [f"cuda:{i}" for i in range(gpus)]
131
+ + ["mps", "mps:0"]
132
+ + [f"npu:{i}" for i in range(gpus)]
133
+ )
134
+ if device and device in device_list:
135
+ self._device = torch.device(device)
136
+ eval_logger.info(f"Using device '{device}'")
137
+ if device in ("mps", "mps:0") and version.parse(
138
+ torch.__version__
139
+ ) < version.parse("2.1"):
140
+ raise RuntimeError(
141
+ f"mps requires torch >= 2.1. You have {torch.__version__}"
142
+ )
143
+ else:
144
+ eval_logger.info("Device not specified")
145
+ eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
146
+ self._device = (
147
+ torch.device("cuda")
148
+ if torch.cuda.is_available()
149
+ else torch.device("cpu")
150
+ )
151
+ else: # Parallelism managed by accelerate
152
+ if device != "cuda":
153
+ eval_logger.info(
154
+ f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
155
+ )
156
+ # TODO: include in warning that `load_in_8bit` etc. affect this too
157
+ self._device = (
158
+ self.accelerator.device
159
+ if hasattr(self, "accelerator")
160
+ else torch.device(device)
161
+ )
162
+
163
+ revision = str(revision) # cast to string if not already one
164
+ # TODO: update this to be less of a hack once subfolder is fixed in HF
165
+ revision = revision + ("/" + subfolder if subfolder is not None else "")
166
+
167
+ self._get_config(
168
+ pretrained,
169
+ revision=revision,
170
+ trust_remote_code=trust_remote_code,
171
+ gguf_file=gguf_file,
172
+ )
173
+
174
+ # determine which of 'causal' and 'seq2seq' backends to use for HF models
175
+ self._get_backend(
176
+ config=self.config, backend=backend, trust_remote_code=trust_remote_code
177
+ )
178
+
179
+ # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
180
+ self._create_tokenizer(
181
+ pretrained,
182
+ tokenizer,
183
+ revision=revision,
184
+ trust_remote_code=trust_remote_code,
185
+ use_fast_tokenizer=use_fast_tokenizer,
186
+ gguf_file=gguf_file,
187
+ add_bos_token=add_bos_token,
188
+ )
189
+
190
+ # if we passed `pretrained` as a string, initialize our model now
191
+ if isinstance(pretrained, str):
192
+ self._create_model(
193
+ pretrained=pretrained,
194
+ revision=revision,
195
+ dtype=dtype,
196
+ trust_remote_code=trust_remote_code,
197
+ parallelize=parallelize,
198
+ gpus=gpus,
199
+ max_memory_per_gpu=max_memory_per_gpu,
200
+ max_cpu_memory=max_cpu_memory,
201
+ offload_folder=offload_folder,
202
+ peft=peft,
203
+ delta=delta,
204
+ autogptq=autogptq,
205
+ gptqmodel=gptqmodel,
206
+ gguf_file=gguf_file,
207
+ **kwargs,
208
+ )
209
+
210
+ # access self._model through self.model property outside this method
211
+ if isinstance(self.model, torch.nn.Module):
212
+ self.model.eval()
213
+ self.model.tie_weights()
214
+
215
+ self.truncation = truncation
216
+ self.logits_cache = logits_cache
217
+ self.vocab_size = self.tokenizer.vocab_size
218
+ # select (or create) a pad token to use
219
+ self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
220
+
221
+ self.add_bos_token = add_bos_token
222
+ if "gemma" in getattr(self.config, "model_type", ""):
223
+ self.add_bos_token = True
224
+ eval_logger.info(
225
+ f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
226
+ )
227
+
228
+ self._max_length = max_length
229
+ self.pretrained = pretrained
230
+ self.delta = delta
231
+ self.peft = peft
232
+ self.revision = revision
233
+ self.batch_schedule = 1
234
+ self.batch_sizes = {}
235
+ self.max_batch_size = max_batch_size
236
+
237
+ if str(batch_size).startswith("auto"):
238
+ batch_size = batch_size.split(":")
239
+ self.batch_size_per_gpu = batch_size[0]
240
+ self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
241
+ else:
242
+ self.batch_size_per_gpu = int(batch_size)
243
+
244
+ if isinstance(pretrained, str):
245
+ if gpus >= 1 or str(self.device) == "mps":
246
+ # TODO: can remove this whole snippet except in the mps case, perhaps?
247
+ if not (parallelize or autogptq or hasattr(self, "accelerator")):
248
+ # place model onto device requested manually,
249
+ # if not using HF Accelerate or device_map
250
+ # or any other option that preloads model onto device
251
+ try:
252
+ self.model.to(self.device)
253
+ except ValueError:
254
+ eval_logger.debug(
255
+ "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
256
+ )
257
+ # multigpu data-parallel support when launched with accelerate
258
+ if gpus > 1:
259
+ if accelerator.num_processes > 1:
260
+ if parallelize:
261
+ eval_logger.warning(
262
+ "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
263
+ )
264
+ elif gpus > accelerator.num_processes:
265
+ eval_logger.warning(
266
+ "WARNING: The number of total system GPUs does not match the number of spawned processes. "
267
+ "If you would like to use data parallelism, please launch the script "
268
+ "with 'accelerate launch *script*'. "
269
+ f"Current run will proceed with {accelerator.num_processes} devices."
270
+ )
271
+ if self.accelerator.is_local_main_process:
272
+ eval_logger.info(
273
+ f"Using {gpus} devices with data parallelism"
274
+ )
275
+
276
+ self._device = torch.device(f"{accelerator.device}")
277
+ self.accelerator = accelerator
278
+
279
+ self._rank = self.accelerator.local_process_index
280
+ self._world_size = self.accelerator.num_processes
281
+ else:
282
+ # if we aren't launching via accelerate, ditch
283
+ self._rank = 0
284
+ self._world_size = 1
285
+ else:
286
+ # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
287
+ eval_logger.warning(
288
+ "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
289
+ )
290
+ self._rank = 0
291
+ self._world_size = 1
292
+
293
+ self.custom_prefix_token_id = prefix_token_id
294
+ if prefix_token_id is not None:
295
+ eval_logger.info(
296
+ f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
297
+ )
298
+
299
+ def _get_accelerate_args(
300
+ self,
301
+ parallelize: Optional[bool] = None,
302
+ device_map: Optional[str] = "auto",
303
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
304
+ max_cpu_memory: Optional[Union[int, str]] = None,
305
+ offload_folder: Optional[str] = "./offload",
306
+ gpus: Optional[int] = None,
307
+ ) -> dict:
308
+ """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
309
+ num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
310
+ num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
311
+ if (
312
+ num_machines == 0
313
+ and hasattr(self, "accelerator")
314
+ and self.accelerator is not None
315
+ ):
316
+ eval_logger.info(
317
+ "We are not in a distributed setting for accelerate. Setting model_parallel to False."
318
+ )
319
+ parallelize = False
320
+
321
+ if parallelize is None:
322
+ # If parallelism is unset by the user, we automatically assign model parallelism
323
+ # if enough extra GPUs are available
324
+ max_memory_all_gpus = get_max_memory()
325
+ # We just want gpu, not cpu, max memory
326
+ if "cpu" in max_memory_all_gpus:
327
+ del max_memory_all_gpus["cpu"]
328
+ parallelize = bool(num_local_processes < len(max_memory_all_gpus))
329
+ eval_logger.info(
330
+ f"Setting model parallel to {parallelize} since "
331
+ f"the number of local processes is {num_local_processes} "
332
+ f"and the number of GPUs is {len(max_memory_all_gpus)}"
333
+ )
334
+
335
+ args = {}
336
+ if parallelize: # Model parallelism will be used
337
+ max_memory = {}
338
+ if max_memory_per_gpu is not None: # Using the provided memory requirements
339
+ max_memory_per_gpu_map = {
340
+ device_idx: max_memory_per_gpu for device_idx in range(gpus)
341
+ }
342
+ else: # Estimating the possible memory requirements
343
+ max_memory_all_gpus = get_max_memory()
344
+ if "cpu" in max_memory_all_gpus:
345
+ del max_memory_all_gpus["cpu"]
346
+ if not hasattr(self, "accelerator"):
347
+ max_memory_per_gpu_map = {
348
+ k: v for k, v in max_memory_all_gpus.items()
349
+ }
350
+ else:
351
+ # use only 1 / num_processes of the GPUs if we are running under accelerate launch
352
+ max_memory_per_gpu_map = {
353
+ k: v
354
+ for k, v in max_memory_all_gpus.items()
355
+ if k % num_local_processes
356
+ == (self.accelerator.process_index % num_local_processes)
357
+ }
358
+ args["max_memory"] = max_memory_per_gpu_map
359
+ args["device_map"] = "auto" if device_map is None else device_map
360
+ eval_logger.info(
361
+ f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
362
+ )
363
+
364
+ if max_cpu_memory is not None:
365
+ max_memory["cpu"] = max_cpu_memory
366
+
367
+ args["offload_folder"] = offload_folder
368
+ elif (
369
+ device_map is None
370
+ ): # No model parallelism, we use the default provided device for our model
371
+ if hasattr(self, "accelerator"):
372
+ device_map = {"": f"{self.accelerator.device}"}
373
+ else:
374
+ device_map = {"": str(self.device)}
375
+ args["max_memory"] = None
376
+ args["device_map"] = device_map
377
+ eval_logger.info(
378
+ f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
379
+ )
380
+ else:
381
+ args["max_memory"] = None
382
+ args["device_map"] = None
383
+ eval_logger.info("Model parallel was set to False.")
384
+
385
+ return args
386
+
387
+ @property
388
+ def config(self):
389
+ # return the associated transformers.AutoConfig for the given pretrained model.
390
+ return self._config
391
+
392
+ @property
393
+ def model(self):
394
+ # returns the model, unwrapping it if using Accelerate
395
+ if hasattr(self, "accelerator"):
396
+ return self.accelerator.unwrap_model(self._model)
397
+ else:
398
+ return self._model
399
+
400
+ @property
401
+ def eot_token_id(self):
402
+ # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
403
+ return self.tokenizer.eos_token_id
404
+
405
+ @property
406
+ def prefix_token_id(self):
407
+ # it is used as prefix for loglikelihood
408
+ if self.custom_prefix_token_id is not None:
409
+ return self.custom_prefix_token_id
410
+ if self.tokenizer.bos_token_id is not None:
411
+ return self.tokenizer.bos_token_id
412
+ return self.tokenizer.eos_token_id
413
+
414
+ @property
415
+ def max_length(self):
416
+ if self._max_length: # if max length manually set, return it
417
+ return self._max_length
418
+ seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
419
+ for attr in seqlen_config_attrs:
420
+ if hasattr(self.model.config, attr):
421
+ return getattr(self.model.config, attr)
422
+ if hasattr(self.tokenizer, "model_max_length"):
423
+ if self.tokenizer.model_max_length == 1000000000000000019884624838656:
424
+ return self._DEFAULT_MAX_LENGTH
425
+ return self.tokenizer.model_max_length
426
+ return self._DEFAULT_MAX_LENGTH
427
+
428
+ @property
429
+ def max_gen_toks(self) -> int:
430
+ return 256
431
+
432
+ @property
433
+ def batch_size(self):
434
+ return self.batch_size_per_gpu
435
+
436
+ @property
437
+ def device(self):
438
+ return self._device
439
+
440
+ @property
441
+ def rank(self):
442
+ return self._rank
443
+
444
+ @property
445
+ def world_size(self):
446
+ return self._world_size
447
+
448
+ @property
449
+ def tokenizer_name(self) -> str:
450
+ return self.tokenizer.name_or_path.replace("/", "__")
451
+
452
+ def _get_backend(
453
+ self,
454
+ config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
455
+ backend: Literal["default", "causal", "seq2seq"] = "default",
456
+ trust_remote_code: Optional[bool] = False,
457
+ ) -> None:
458
+ """
459
+ Helper method during initialization.
460
+ Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
461
+ sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
462
+
463
+ **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
464
+ user must set `self.backend` to be either "causal" or "seq2seq" manually!**
465
+ """
466
+
467
+ assert backend in ["default", "causal", "seq2seq"]
468
+
469
+ if backend != "default":
470
+ # if we've settled on non-default backend, use that manually
471
+ if backend == "causal":
472
+ self.backend = backend
473
+ elif backend == "seq2seq":
474
+ self.backend = backend
475
+ eval_logger.info(
476
+ f"Overrode HF model backend type, and using type '{self.backend}'"
477
+ )
478
+ else:
479
+ # determine and use the default HF backend for this model, based on its config + metadata.
480
+ if (
481
+ getattr(config, "model_type")
482
+ in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
483
+ ):
484
+ # first check if model type is listed under seq2seq models, since some
485
+ # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
486
+ # these special cases should be treated as seq2seq models.
487
+ self.backend = "seq2seq"
488
+ eval_logger.debug(f"Using model type '{self.backend}'")
489
+ elif (
490
+ getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
491
+ ):
492
+ self.backend = "causal"
493
+ eval_logger.debug(f"Using model type '{self.backend}'")
494
+ else:
495
+ if not trust_remote_code:
496
+ eval_logger.warning(
497
+ "HF model type is neither marked as CausalLM or Seq2SeqLM. \
498
+ This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
499
+ "Setting backend to causal"
500
+ )
501
+ # if model type is neither in HF transformers causal or seq2seq model registries
502
+ # then we default to assuming AutoModelForCausalLM
503
+ self.backend = "causal"
504
+ eval_logger.info(
505
+ f"Model type cannot be determined. Using default model type '{self.backend}'"
506
+ )
507
+
508
+ if self.AUTO_MODEL_CLASS is None:
509
+ if self.backend == "causal":
510
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
511
+ elif self.backend == "seq2seq":
512
+ self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
513
+
514
+ def _get_config(
515
+ self,
516
+ pretrained: str,
517
+ revision: str = "main",
518
+ trust_remote_code: bool = False,
519
+ gguf_file: Optional[str] = None,
520
+ ) -> None:
521
+ """Return the model config for HuggingFace models"""
522
+ self._config = transformers.AutoConfig.from_pretrained(
523
+ pretrained,
524
+ revision=revision,
525
+ trust_remote_code=trust_remote_code,
526
+ gguf_file=gguf_file,
527
+ )
528
+
529
+ def _create_model(
530
+ self,
531
+ pretrained: str,
532
+ revision: Optional[str] = "main",
533
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
534
+ trust_remote_code: Optional[bool] = False,
535
+ # arguments used for splitting a model across GPUs naively.
536
+ # only used if `parallelize=True`.
537
+ # (accelerate naive PP (device_map) options)
538
+ parallelize: Optional[bool] = False,
539
+ gpus: Optional[int] = None,
540
+ max_memory_per_gpu: Optional[Union[int, str]] = None,
541
+ max_cpu_memory: Optional[Union[int, str]] = None,
542
+ offload_folder: Optional[str] = "./offload",
543
+ # PEFT, delta weights and quantization options
544
+ peft: Optional[str] = None,
545
+ delta: Optional[str] = None,
546
+ autogptq: Optional[Union[bool, str]] = False,
547
+ gptqmodel: Optional[bool] = False,
548
+ gguf_file: Optional[str] = None,
549
+ **kwargs,
550
+ ) -> None:
551
+ """
552
+ Initializes an HF or HF-compatible PreTrainedModel from scratch
553
+ inside HFLM, using the kwargs passed into self.__init__().
554
+
555
+ Also handles functionality such as AutoGPTQ usage and PEFT wrapping.
556
+
557
+ For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
558
+ (such as PyTorch models that are nearly, but not quite, fully mirroring
559
+ HF's public interface relied on in this HFLM class)
560
+ please consider subclassing HFLM and overriding this and other methods as needed.
561
+ """
562
+
563
+ model_kwargs = kwargs if kwargs else {}
564
+
565
+ model_kwargs.update(
566
+ self._get_accelerate_args(
567
+ parallelize=parallelize,
568
+ device_map=kwargs.get("device_map", None),
569
+ max_memory_per_gpu=max_memory_per_gpu,
570
+ max_cpu_memory=max_cpu_memory,
571
+ offload_folder=offload_folder,
572
+ gpus=gpus,
573
+ )
574
+ )
575
+
576
+ if not autogptq and not gptqmodel:
577
+ if model_kwargs.get("load_in_4bit", None):
578
+ assert transformers.__version__ >= "4.30.0", (
579
+ "load_in_4bit requires transformers >= 4.30.0"
580
+ )
581
+ if transformers.__version__ >= "4.30.0":
582
+ if model_kwargs.get("load_in_4bit", None):
583
+ if model_kwargs.get("bnb_4bit_compute_dtype", None):
584
+ model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
585
+ model_kwargs["bnb_4bit_compute_dtype"]
586
+ )
587
+
588
+ self._model = self.AUTO_MODEL_CLASS.from_pretrained(
589
+ pretrained,
590
+ revision=revision,
591
+ torch_dtype=get_dtype(dtype),
592
+ trust_remote_code=trust_remote_code,
593
+ gguf_file=gguf_file,
594
+ **model_kwargs,
595
+ )
596
+ else:
597
+ if autogptq and gptqmodel:
598
+ raise ValueError(
599
+ "Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
600
+ )
601
+
602
+ if autogptq:
603
+ try:
604
+ from auto_gptq import AutoGPTQForCausalLM
605
+ except ModuleNotFoundError as exception:
606
+ raise type(exception)(
607
+ "Tried to load auto_gptq, but auto-gptq is not installed ",
608
+ "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
609
+ )
610
+
611
+ self._model = AutoGPTQForCausalLM.from_quantized(
612
+ pretrained,
613
+ trust_remote_code=trust_remote_code,
614
+ model_basename=None if autogptq is True else Path(autogptq).stem,
615
+ use_safetensors=True
616
+ if autogptq is True
617
+ else autogptq.endswith(".safetensors"),
618
+ **model_kwargs,
619
+ )
620
+
621
+ if gptqmodel:
622
+ try:
623
+ from gptqmodel import GPTQModel
624
+ except ModuleNotFoundError as exception:
625
+ raise type(exception)(
626
+ "Tried to load gptqmodel, but gptqmodel is not installed ",
627
+ "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
628
+ )
629
+
630
+ self._model = GPTQModel.from_quantized(
631
+ pretrained, trust_remote_code=trust_remote_code, **model_kwargs
632
+ )
633
+
634
+ if peft and delta:
635
+ raise ValueError(
636
+ "Cannot use both 'peft' and 'delta' options at the same time."
637
+ )
638
+
639
+ if peft:
640
+ if model_kwargs.get("load_in_4bit", None):
641
+ if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
642
+ raise AssertionError("load_in_4bit requires peft >= 0.4.0")
643
+ if self._model.config.vocab_size != len(self.tokenizer):
644
+ # resize model for LoRAs with added tokens
645
+ eval_logger.info(
646
+ f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
647
+ )
648
+ self._model.resize_token_embeddings(len(self.tokenizer))
649
+ self._model = PeftModel.from_pretrained(
650
+ self._model, peft, revision=revision
651
+ )
652
+ elif delta:
653
+ if autogptq:
654
+ eval_logger.warning(
655
+ "Delta weights might trigger unexpected behavior when used with AutoGPTQ."
656
+ )
657
+ _model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
658
+ delta,
659
+ revision=revision,
660
+ torch_dtype=get_dtype(dtype),
661
+ trust_remote_code=trust_remote_code,
662
+ **model_kwargs,
663
+ )
664
+ for name, param in self._model.state_dict().items():
665
+ try:
666
+ param.data += _model_delta.state_dict()[name]
667
+ except KeyError:
668
+ raise KeyError(f"Delta model is missing weights for layer: {name}")
669
+ except Exception as e:
670
+ raise RuntimeError(
671
+ f"Failed to add delta weights to layer {name}. Error: {e}"
672
+ )
673
+
674
+ del _model_delta
675
+
676
+ return None
677
+
678
+ def _create_tokenizer(
679
+ self,
680
+ pretrained: Union[str, transformers.PreTrainedModel],
681
+ tokenizer: Optional[
682
+ Union[
683
+ str,
684
+ transformers.PreTrainedTokenizer,
685
+ transformers.PreTrainedTokenizerFast,
686
+ ]
687
+ ],
688
+ revision: Optional[str] = "main",
689
+ trust_remote_code: Optional[bool] = False,
690
+ use_fast_tokenizer: Optional[bool] = True,
691
+ gguf_file: Optional[str] = None,
692
+ add_bos_token: Optional[bool] = False,
693
+ ) -> None:
694
+ """
695
+ Helper method during initialization.
696
+
697
+ Create a tokenizer object corresponding to the correct
698
+ tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
699
+ """
700
+ kwargs = {
701
+ "revision": revision,
702
+ "trust_remote_code": trust_remote_code,
703
+ }
704
+
705
+ # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
706
+ if gguf_file is not None:
707
+ kwargs["gguf_file"] = gguf_file
708
+ else:
709
+ kwargs["use_fast"] = use_fast_tokenizer
710
+
711
+ if add_bos_token:
712
+ kwargs["add_bos_token"] = True
713
+
714
+ if tokenizer:
715
+ if isinstance(tokenizer, str):
716
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
717
+ tokenizer, **kwargs
718
+ )
719
+ else:
720
+ assert isinstance(
721
+ tokenizer, transformers.PreTrainedTokenizer
722
+ ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
723
+ self.tokenizer = tokenizer
724
+ else:
725
+ # Get tokenizer based on 'pretrained'
726
+ if isinstance(pretrained, str):
727
+ model_name = pretrained
728
+ else:
729
+ # get the HF hub name via accessor on model
730
+ model_name = self.model.name_or_path
731
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(
732
+ model_name, **kwargs
733
+ )
734
+ return None
735
+
736
+ def _detect_batch_size(self, requests=None, pos: int = 0):
737
+ if requests:
738
+ _, context_enc, continuation_enc = requests[pos]
739
+ max_length = len(
740
+ (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
741
+ )
742
+ max_context_enc = len(context_enc[-(self.max_length + 1) :])
743
+ max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
744
+ else:
745
+ max_length = self.max_length
746
+ max_context_enc = max_length
747
+ max_cont_enc = max_length
748
+
749
+ # if OOM, then halves batch_size and tries again
750
+ @find_executable_batch_size(starting_batch_size=self.max_batch_size)
751
+ def forward_batch(batch_size):
752
+ if self.backend == "seq2seq":
753
+ length = max(max_context_enc, max_cont_enc)
754
+ batched_conts = torch.ones(
755
+ (batch_size, length), device=self.device
756
+ ).long()
757
+ test_batch = torch.ones((batch_size, length), device=self.device).long()
758
+ call_kwargs = {
759
+ "attn_mask": test_batch,
760
+ "labels": batched_conts,
761
+ }
762
+ else:
763
+ call_kwargs = {}
764
+ test_batch = torch.ones(
765
+ (batch_size, max_length), device=self.device
766
+ ).long()
767
+ for _ in range(5):
768
+ out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1) # noqa: F841
769
+
770
+ return batch_size
771
+
772
+ try:
773
+ batch_size = forward_batch()
774
+ except RuntimeError as e:
775
+ if "No executable batch size found" in str(e):
776
+ batch_size = 1
777
+ else:
778
+ raise
779
+
780
+ if self.world_size > 1:
781
+ # if multi-GPU, always take minimum over all selected batch sizes
782
+ max_rnk_bs = torch.tensor([batch_size], device=self.device)
783
+ gathered = (
784
+ self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
785
+ )
786
+ batch_size = min(gathered)
787
+ clear_torch_cache()
788
+ return batch_size
789
+
790
+ clear_torch_cache()
791
+ return batch_size
792
+
793
+ def tok_encode(
794
+ self, string: str, left_truncate_len=None, add_special_tokens=None
795
+ ) -> List[int]:
796
+ """ """
797
+ # default for None - empty dict, use predefined tokenizer param
798
+ # used for all models except for CausalLM or predefined value
799
+ special_tokens_kwargs = {}
800
+
801
+ # by default for CausalLM - false or self.add_bos_token is set
802
+ if add_special_tokens is None:
803
+ if self.backend == "causal":
804
+ special_tokens_kwargs = {
805
+ "add_special_tokens": False or self.add_bos_token
806
+ }
807
+ # otherwise the method explicitly defines the value
808
+ else:
809
+ special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
810
+
811
+ encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
812
+
813
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
814
+ if left_truncate_len:
815
+ encoding = encoding[-left_truncate_len:]
816
+
817
+ return encoding
818
+
819
+ def tok_batch_encode(
820
+ self,
821
+ strings: List[str],
822
+ padding_side: str = "left",
823
+ left_truncate_len: int = None,
824
+ truncation: bool = False,
825
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
826
+ # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
827
+ old_padding_side = self.tokenizer.padding_side
828
+ self.tokenizer.padding_side = padding_side
829
+
830
+ add_special_tokens = {}
831
+ if self.backend == "causal":
832
+ add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
833
+
834
+ encoding = self.tokenizer(
835
+ strings,
836
+ truncation=truncation,
837
+ padding="longest",
838
+ return_tensors="pt",
839
+ **add_special_tokens,
840
+ )
841
+ if left_truncate_len:
842
+ original_lengths = encoding["input_ids"].size(1)
843
+ if original_lengths > left_truncate_len:
844
+ eval_logger.warn(
845
+ f"Left truncation applied. Original sequence length was {original_lengths}, "
846
+ f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
847
+ )
848
+ encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
849
+ encoding["attention_mask"] = encoding["attention_mask"][
850
+ :, -left_truncate_len:
851
+ ]
852
+ self.tokenizer.padding_side = old_padding_side
853
+
854
+ return encoding["input_ids"], encoding["attention_mask"]
855
+
856
+ def tok_decode(self, tokens, skip_special_tokens=True):
857
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
858
+
859
+ def _model_call(self, inps, attn_mask=None, labels=None):
860
+ """
861
+ :param inps: torch.Tensor
862
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
863
+ [batch, sequence_ctx]. the size of sequence may vary from call to call
864
+ :param attn_mask: torch.Tensor, optional
865
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
866
+ (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
867
+ :param labels: torch.Tensor, optional
868
+ A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
869
+ (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
870
+ :return
871
+ A torch tensor of shape [batch, sequence, vocab] with the
872
+ logits returned from the model's decoder
873
+ """
874
+ with torch.no_grad():
875
+ if attn_mask is not None or labels is not None:
876
+ assert attn_mask is not None and labels is not None
877
+ assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
878
+ return self.model(
879
+ input_ids=inps, attention_mask=attn_mask, labels=labels
880
+ ).logits
881
+ else:
882
+ assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
883
+ return self.model(inps).logits
884
+
885
+ def _model_generate(self, context, max_length, stop, **generation_kwargs):
886
+ # temperature = 0.0 if not set
887
+ # if do_sample is false and temp==0.0:
888
+ # remove temperature, as do_sample=False takes care of this
889
+ # and we don't want a warning from HF
890
+ generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
891
+ do_sample = generation_kwargs.get("do_sample", None)
892
+
893
+ # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
894
+ if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
895
+ generation_kwargs["do_sample"] = do_sample = False
896
+
897
+ if do_sample is False and generation_kwargs.get("temperature") == 0.0:
898
+ generation_kwargs.pop("temperature")
899
+ # build stopping criteria
900
+ stopping_criteria = stop_sequences_criteria(
901
+ self.tokenizer, stop, context.shape[1], context.shape[0]
902
+ )
903
+ return self.model.generate(
904
+ input_ids=context,
905
+ max_length=max_length,
906
+ stopping_criteria=stopping_criteria,
907
+ pad_token_id=self.tokenizer.pad_token_id,
908
+ use_cache=True,
909
+ **generation_kwargs,
910
+ )
911
+
912
+ def _select_cont_toks(
913
+ self, logits: torch.Tensor, contlen: int = None, inplen: int = None
914
+ ) -> torch.Tensor:
915
+ if self.backend == "causal":
916
+ assert contlen and inplen, (
917
+ "Must pass input len and cont. len to select scored logits for causal LM"
918
+ )
919
+ # discard right-padding.
920
+ # also discard the input/context tokens. we'll only score continuations.
921
+ logits = logits[inplen - contlen : inplen]
922
+ elif self.backend == "seq2seq":
923
+ assert contlen and not inplen, (
924
+ "Selecting scored logits for Seq2SeqLM requires only cont. len"
925
+ )
926
+ # only discard right-padding.
927
+ # the logits input to this fn only contain decoder-side tokens.
928
+ logits = logits[:contlen]
929
+
930
+ return logits
931
+
932
+ def loglikelihood_rolling(
933
+ self, requests: List[Instance], disable_tqdm: bool = False
934
+ ) -> List[float]:
935
+ adaptive_batch_size = None
936
+ if self.batch_size == "auto":
937
+ # using rolling window with maximum context
938
+ print("Passed argument batch_size = auto. Detecting largest batch size")
939
+ batch_size = self._detect_batch_size()
940
+ print(f"Determined Largest batch size: {batch_size}")
941
+ adaptive_batch_size = batch_size
942
+
943
+ # First, collect all windows from all requests
944
+ all_windows = [] # List of (request_idx, window) tuples
945
+ request_window_counts = [] # Track number of windows per request
946
+
947
+ for req_idx, (string,) in enumerate(
948
+ tqdm(
949
+ [req.args for req in requests],
950
+ disable=(disable_tqdm or (self.rank != 0)),
951
+ )
952
+ ):
953
+ rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
954
+ map(
955
+ utils.make_disjoint_window,
956
+ utils.get_rolling_token_windows(
957
+ token_list=self.tok_encode(string),
958
+ prefix_token=self.prefix_token_id,
959
+ max_seq_len=self.max_length,
960
+ context_len=1,
961
+ ),
962
+ )
963
+ )
964
+
965
+ # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
966
+ windows = [(None,) + x for x in rolling_token_windows]
967
+
968
+ # Store windows with their request index
969
+ all_windows.extend((req_idx, window) for window in windows)
970
+ request_window_counts.append(len(windows))
971
+
972
+ # Handle distributed case padding
973
+ pad_amnt = 0
974
+ if self.world_size > 1:
975
+ mytensor = torch.tensor(len(all_windows), device=self.device)
976
+ gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
977
+ pad_amnt = max(gathered) - gathered[self.rank]
978
+ if pad_amnt > 0:
979
+ all_windows += pad_amnt * [all_windows[0]]
980
+
981
+ all_nlls = []
982
+ batch_size = adaptive_batch_size or self.batch_size
983
+ for i in range(0, len(all_windows), batch_size):
984
+ batch = all_windows[i : i + batch_size]
985
+ # Extract just the windows for processing, keeping track of request indices
986
+ batch_indices, batch_windows = zip(*batch)
987
+
988
+ batch_nlls = self._loglikelihood_tokens(
989
+ requests=batch_windows,
990
+ disable_tqdm=False,
991
+ override_bs=len(batch_windows),
992
+ )
993
+ # Store results with their request indices
994
+ all_nlls.extend(zip(batch_indices, batch_nlls))
995
+
996
+ # Remove padding if necessary
997
+ if (self.world_size > 1) and (pad_amnt > 0):
998
+ all_nlls = all_nlls[:-pad_amnt]
999
+
1000
+ # Reconstruct per-request loglikelihoods
1001
+ loglikelihoods = []
1002
+ current_idx = 0
1003
+ for window_count in request_window_counts:
1004
+ # Get all nlls for this request
1005
+ request_nlls = all_nlls[current_idx : current_idx + window_count]
1006
+ # Sum up the nlls for this request (discarding is_greedy)
1007
+ request_total = sum(nll[0] for _, nll in request_nlls)
1008
+ loglikelihoods.append(request_total)
1009
+ current_idx += window_count
1010
+
1011
+ string = requests[len(loglikelihoods) - 1].args[0]
1012
+ self.cache_hook.add_partial(
1013
+ "loglikelihood_rolling", (string,), request_total
1014
+ )
1015
+
1016
+ return loglikelihoods
1017
+
1018
+ def _batch_scheduler(self, pos, n_reordered_requests):
1019
+ sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
1020
+ if sched in self.batch_sizes:
1021
+ return self.batch_sizes[sched]
1022
+ if (len(self.batch_sizes) > 1) and (
1023
+ self.batch_sizes[sched - 1] == self.max_batch_size
1024
+ ):
1025
+ # if previous batch size is already maximal, skip recomputation
1026
+ self.batch_sizes[sched] = self.max_batch_size
1027
+ return self.batch_sizes[sched]
1028
+ print(
1029
+ f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
1030
+ )
1031
+ self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
1032
+ print(f"Determined largest batch size: {self.batch_sizes[sched]}")
1033
+ return self.batch_sizes[sched]
1034
+
1035
+ def _loglikelihood_tokens(
1036
+ self,
1037
+ requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
1038
+ disable_tqdm: bool = False,
1039
+ override_bs: int = None,
1040
+ ) -> List[Tuple[float, bool]]:
1041
+ # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
1042
+ res = []
1043
+
1044
+ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
1045
+ """Defines the key for the sorted method"""
1046
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
1047
+ # - time estimates will always be over not underestimates, which is more useful for planning
1048
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
1049
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
1050
+ # automatic adaptive batches much much easier to implement
1051
+ # - any OOMs will happen right away rather than near the end
1052
+
1053
+ toks = req[1] + req[2]
1054
+ return -len(toks), tuple(toks)
1055
+
1056
+ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
1057
+ """Defines the key to group and lookup one-token continuations"""
1058
+ # Use with group_by="contexts" (optional)"
1059
+ # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
1060
+ # speeds up some multiple-choice tasks proportionally to the number of choices.
1061
+ # groups requests by context+continuation[:-1] and infer on one request/group.
1062
+ return req[-2] + req[-1][:-1]
1063
+
1064
+ re_ord = Collator(
1065
+ requests,
1066
+ sort_fn=_collate,
1067
+ group_by="contexts"
1068
+ if self.backend == "causal" and self.logits_cache
1069
+ else None,
1070
+ group_fn=_lookup_one_token_cont,
1071
+ )
1072
+
1073
+ # automatic (variable) batch size detection for vectorization
1074
+ # pull longest context sample from request
1075
+ n_reordered_requests = len(re_ord)
1076
+ batch_size = (
1077
+ self.batch_size
1078
+ if self.batch_size != "auto"
1079
+ else override_bs
1080
+ if override_bs is not None
1081
+ else 0
1082
+ )
1083
+ batch_fn = (
1084
+ self._batch_scheduler
1085
+ if self.batch_size == "auto"
1086
+ and n_reordered_requests > 0
1087
+ and not override_bs
1088
+ else None
1089
+ )
1090
+
1091
+ chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
1092
+ pbar = tqdm(
1093
+ total=len(requests),
1094
+ disable=(disable_tqdm or (self.rank != 0)),
1095
+ desc="Running loglikelihood requests",
1096
+ )
1097
+ for chunk in chunks:
1098
+ inps = []
1099
+ cont_toks_list = []
1100
+ inplens = []
1101
+
1102
+ conts = []
1103
+ encoder_attns = []
1104
+
1105
+ padding_len_inp = None
1106
+ padding_len_cont = None
1107
+ # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
1108
+ # tensors, then we pack them together into a batch, call the model, and then pick it all apart
1109
+ # again because vectorizing is annoying
1110
+
1111
+ for _, context_enc, continuation_enc in chunk:
1112
+ # sanity check
1113
+ assert len(context_enc) > 0
1114
+ assert len(continuation_enc) > 0
1115
+ assert len(continuation_enc) <= self.max_length
1116
+
1117
+ # how this all works (illustrated on a causal decoder-only setup):
1118
+ # CTX CONT
1119
+ # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
1120
+ # model \ \
1121
+ # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
1122
+ # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
1123
+
1124
+ # when too long to fit in context, truncate from the left
1125
+ if self.backend == "causal":
1126
+ total_length = len(context_enc) + len(continuation_enc)
1127
+ if total_length > self.max_length + 1:
1128
+ eval_logger.warn(
1129
+ f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
1130
+ f"exceeds model's maximum length ({self.max_length}). "
1131
+ f"Truncating {total_length - self.max_length + 1} tokens from the left."
1132
+ )
1133
+ inp = torch.tensor(
1134
+ (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
1135
+ dtype=torch.long,
1136
+ device=self.device,
1137
+ )
1138
+ (inplen,) = inp.shape
1139
+ elif self.backend == "seq2seq":
1140
+ inp = torch.tensor(
1141
+ (context_enc)[-self.max_length :],
1142
+ dtype=torch.long,
1143
+ device=self.device,
1144
+ )
1145
+ (inplen,) = inp.shape
1146
+
1147
+ # build encoder attn masks
1148
+ encoder_attns.append(torch.ones_like(inp))
1149
+
1150
+ cont = torch.tensor(
1151
+ (continuation_enc)[-self.max_length :],
1152
+ # TODO: left-shift these?
1153
+ # TODO: our code assumes we never end up truncating conts for either model type
1154
+ dtype=torch.long,
1155
+ device=self.device,
1156
+ )
1157
+ (contlen,) = cont.shape
1158
+
1159
+ conts.append(cont)
1160
+
1161
+ padding_len_cont = (
1162
+ max(padding_len_cont, contlen)
1163
+ if padding_len_cont is not None
1164
+ else contlen
1165
+ )
1166
+
1167
+ padding_len_inp = (
1168
+ max(padding_len_inp, inplen)
1169
+ if padding_len_inp is not None
1170
+ else inplen
1171
+ )
1172
+
1173
+ inps.append(inp) # [1, inp_length]
1174
+ cont_toks_list.append(continuation_enc)
1175
+ inplens.append(inplen)
1176
+
1177
+ # create encoder attn mask and batched conts, if seq2seq
1178
+ call_kwargs = {}
1179
+ if self.backend == "causal":
1180
+ batched_inps = pad_and_concat(
1181
+ padding_len_inp, inps, padding_side="right"
1182
+ ) # [batch, padding_len_inp]
1183
+ elif self.backend == "seq2seq":
1184
+ # TODO: left-pad encoder inps and mask?
1185
+ batched_inps = pad_and_concat(
1186
+ padding_len_inp, inps
1187
+ ) # [batch, padding_len_inp]
1188
+ batched_conts = pad_and_concat(
1189
+ padding_len_cont, conts
1190
+ ) # [batch, padding_len_cont]
1191
+ batched_encoder_mask = pad_and_concat(
1192
+ padding_len_inp, encoder_attns
1193
+ ) # [batch, padding_len_inp]
1194
+ call_kwargs = {
1195
+ "attn_mask": batched_encoder_mask,
1196
+ "labels": batched_conts,
1197
+ }
1198
+
1199
+ multi_logits = F.log_softmax(
1200
+ self._model_call(batched_inps, **call_kwargs), dim=-1
1201
+ ) # [batch, padding_length (inp or cont), vocab]
1202
+
1203
+ for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
1204
+ chunk, multi_logits, inplens, cont_toks_list
1205
+ ):
1206
+ # Slice to original seq length
1207
+ contlen = len(cont_toks)
1208
+ # take only logits in the continuation
1209
+ # (discard context toks if decoder-only ; discard right-padding)
1210
+ # also discards + checks for "virtual tokens" in the causal LM's input window
1211
+ # from prompt/prefix tuning tokens, if applicable
1212
+ ctx_len = (
1213
+ inplen + (logits.shape[0] - padding_len_inp)
1214
+ if self.backend == "causal"
1215
+ else None
1216
+ )
1217
+ logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
1218
+ logits = logits.unsqueeze(0) # [1, seq, vocab]
1219
+
1220
+ # Check if per-token argmax is exactly equal to continuation
1221
+ greedy_tokens = logits.argmax(dim=-1)
1222
+
1223
+ # check for one-token continuation cache hits.
1224
+ # noop in case group_by != "contexts" or no cache hit and returns the
1225
+ # original args. Otherwise, expands the logits batch dimension and yields each
1226
+ # batch along with matching continuation tokens and prompt strings.
1227
+ # logits -> [1, seq, vocab]
1228
+ for request_str, cont_toks, logits in re_ord.get_cache(
1229
+ req_str=request_str,
1230
+ cxt_toks=ctx_tokens,
1231
+ cont_toks=cont_toks,
1232
+ logits=logits,
1233
+ ):
1234
+ cont_toks = torch.tensor(
1235
+ cont_toks, dtype=torch.long, device=self.device
1236
+ ).unsqueeze(0) # [1, seq]
1237
+ max_equal = (greedy_tokens == cont_toks).all()
1238
+
1239
+ # Obtain log-probs at the corresponding continuation token indices
1240
+ # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
1241
+ logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
1242
+ -1
1243
+ ) # [1, seq]
1244
+
1245
+ # Answer: (log prob, is-exact-match)
1246
+ answer = (float(logits.sum()), bool(max_equal))
1247
+
1248
+ res.append(answer)
1249
+
1250
+ if request_str is not None:
1251
+ # special case: loglikelihood_rolling produces a number of loglikelihood requests
1252
+ # all with cache key None. instead do add_partial on the per-example level
1253
+ # in the loglikelihood_rolling() function for those.
1254
+ self.cache_hook.add_partial(
1255
+ "loglikelihood", request_str, answer
1256
+ )
1257
+ pbar.update(1)
1258
+
1259
+ pbar.close()
1260
+
1261
+ return re_ord.get_original(res)
1262
+
1263
+ def generate_until(
1264
+ self, requests: List[Instance], disable_tqdm: bool = False
1265
+ ) -> List[str]:
1266
+ res = []
1267
+
1268
+ def _collate(req: Tuple[str, dict]):
1269
+ """Defines the key for the sorted method"""
1270
+ # the negative sign on len(toks) sorts descending - this has a few advantages:
1271
+ # - time estimates will always be over not underestimates, which is more useful for planning
1272
+ # - to know the size of a batch when going through the list, you know the first one is always the batch
1273
+ # padded context length. this is useful to simplify the batching logic and more importantly to make
1274
+ # automatic adaptive batches much much easier to implement
1275
+ # - any OOMs will happen right away rather than near the end
1276
+ toks = self.tok_encode(req[0])
1277
+ return -len(toks), req[0]
1278
+
1279
+ pbar = tqdm(
1280
+ total=len(requests),
1281
+ disable=(disable_tqdm or (self.rank != 0)),
1282
+ desc="Running generate_until requests",
1283
+ )
1284
+ adaptive_batch_size = None
1285
+ if self.batch_size == "auto":
1286
+ # using rolling window with maximum context
1287
+ print("Passed argument batch_size = auto. Detecting largest batch size")
1288
+ batch_size = self._detect_batch_size()
1289
+ print(f"Determined Largest batch size: {batch_size}")
1290
+ adaptive_batch_size = batch_size
1291
+ # for each different set of kwargs, we execute all requests, by batch.
1292
+ batch_size = (
1293
+ self.batch_size
1294
+ if self.batch_size != "auto"
1295
+ else adaptive_batch_size
1296
+ if adaptive_batch_size is not None
1297
+ else 0
1298
+ )
1299
+ batch_fn = (
1300
+ self._batch_scheduler
1301
+ if self.batch_size == "auto" and not adaptive_batch_size
1302
+ else None
1303
+ )
1304
+
1305
+ # we group requests by their generation_kwargs,
1306
+ # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
1307
+ # in the same batch.
1308
+ # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
1309
+ re_ords = Collator(
1310
+ [reg.args for reg in requests],
1311
+ sort_fn=_collate,
1312
+ group_by="gen_kwargs",
1313
+ group_fn=lambda x: x[1],
1314
+ )
1315
+ chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
1316
+ eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
1317
+ for chunk in chunks:
1318
+ contexts, all_gen_kwargs = zip(*chunk)
1319
+ # we assume all gen kwargs in the batch are the same
1320
+ # this is safe to assume because the `grouper` object ensures it.
1321
+ gen_kwargs = all_gen_kwargs[0]
1322
+ # unpack our keyword arguments.
1323
+ if isinstance(gen_kwargs, dict):
1324
+ kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
1325
+ # add EOS token to stop sequences
1326
+ until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
1327
+ else:
1328
+ raise ValueError(
1329
+ f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
1330
+ )
1331
+ if "max_gen_toks" in kwargs.keys():
1332
+ max_gen_toks = kwargs.pop("max_gen_toks")
1333
+ else:
1334
+ max_gen_toks = self.max_gen_toks
1335
+
1336
+ # set the max length in tokens of inputs ("context_enc")
1337
+ if self.backend == "causal":
1338
+ # max len for inputs = max length, minus room to generate the max new tokens
1339
+ max_ctx_len = self.max_length - max_gen_toks
1340
+ assert max_ctx_len > 0, (
1341
+ f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
1342
+ )
1343
+ elif self.backend == "seq2seq":
1344
+ # max len for inputs = encoder's whole max_length
1345
+ max_ctx_len = self.max_length
1346
+
1347
+ # encode, pad, and truncate contexts for this batch
1348
+ context_enc, attn_masks = self.tok_batch_encode(
1349
+ contexts,
1350
+ left_truncate_len=max_ctx_len,
1351
+ truncation=self.truncation,
1352
+ )
1353
+ context_enc = context_enc.to(self.device)
1354
+ attn_masks = attn_masks.to(self.device)
1355
+
1356
+ if "max_length" not in kwargs:
1357
+ kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
1358
+
1359
+ # perform batched generation
1360
+ cont = self._model_generate(
1361
+ context=context_enc,
1362
+ attention_mask=attn_masks,
1363
+ stop=until,
1364
+ **kwargs,
1365
+ )
1366
+
1367
+ cont_toks_list = cont.tolist()
1368
+ for cont_toks, context in zip(cont_toks_list, contexts):
1369
+ # discard context + left-padding toks if using causal decoder-only LM
1370
+ if self.backend == "causal":
1371
+ cont_toks = cont_toks[context_enc.shape[1] :]
1372
+
1373
+ s = self.tok_decode(cont_toks)
1374
+
1375
+ # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
1376
+ for term in until:
1377
+ if len(term) > 0:
1378
+ # ignore '' separator,
1379
+ # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
1380
+ s = s.split(term)[0]
1381
+
1382
+ res.append(s)
1383
+
1384
+ self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
1385
+ pbar.update(1)
1386
+ # reorder this group of results back to original unsorted form
1387
+ res = re_ords.get_original(res)
1388
+
1389
+ pbar.close()
1390
+
1391
+ return res
1392
+
1393
+ def apply_chat_template(
1394
+ self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
1395
+ ) -> str:
1396
+ """
1397
+ Method to apply a chat template to a list of chat history between user and model.
1398
+ """
1399
+ try:
1400
+ chat_templated = self.tokenizer.apply_chat_template(
1401
+ chat_history,
1402
+ tokenize=False,
1403
+ add_generation_prompt=add_generation_prompt,
1404
+ continue_final_message=not add_generation_prompt,
1405
+ )
1406
+ except jinja2.exceptions.TemplateError:
1407
+ eval_logger.warning(
1408
+ "Failed to apply chat template. removing the system role in chat history."
1409
+ )
1410
+ chat_history = [msg for msg in chat_history if msg["role"] != "system"]
1411
+ chat_templated = self.tokenizer.apply_chat_template(
1412
+ chat_history,
1413
+ tokenize=False,
1414
+ add_generation_prompt=add_generation_prompt,
1415
+ continue_final_message=not add_generation_prompt,
1416
+ )
1417
+
1418
+ return chat_templated
1419
+
1420
+ def get_model_info(self) -> dict:
1421
+ """
1422
+ Method to get Hugging Face model information for experiment reproducibility.
1423
+ """
1424
+
1425
+ def get_model_num_params(model) -> int:
1426
+ if hasattr(model, "num_parameters"):
1427
+ return model.num_parameters()
1428
+ if hasattr(model, "parameters"):
1429
+ return sum(p.numel() for p in model.parameters())
1430
+ else:
1431
+ return -1
1432
+
1433
+ def get_model_dtype(model) -> str:
1434
+ if hasattr(model, "dtype"):
1435
+ return model.dtype
1436
+ else:
1437
+ return ""
1438
+
1439
+ def get_model_sha(pretrained: str, revision: str) -> str:
1440
+ try:
1441
+ model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
1442
+ return model_info.sha
1443
+ except Exception as e:
1444
+ eval_logger.debug(
1445
+ f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
1446
+ )
1447
+ return ""
1448
+
1449
+ model_info = {
1450
+ "model_num_parameters": get_model_num_params(self._model),
1451
+ "model_dtype": get_model_dtype(self._model),
1452
+ "model_revision": self.revision,
1453
+ "model_sha": get_model_sha(self.pretrained, self.revision),
1454
+ }
1455
+ if self.peft:
1456
+ model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
1457
+ if self.delta:
1458
+ model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
1459
+ return model_info
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/utils.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import fnmatch
3
+ import gc
4
+ import itertools
5
+ import logging
6
+ import time
7
+ from functools import wraps
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ Iterable,
14
+ Iterator,
15
+ List,
16
+ Literal,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ Union,
21
+ )
22
+
23
+ import torch
24
+ import transformers
25
+
26
+
27
+ eval_logger = logging.getLogger(__name__)
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers import PreTrainedTokenizerBase
32
+ from transformers.configuration_utils import PretrainedConfig
33
+
34
+
35
+ def chunks(iter, n: int = 0, fn=None):
36
+ """
37
+ Divides an iterable into chunks of specified size or based on a given function.
38
+ Useful for batching
39
+
40
+ Parameters:
41
+ - iter: The input iterable to be divided into chunks.
42
+ - n: An integer representing the size of each chunk. Default is 0.
43
+ - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
44
+
45
+ Returns:
46
+ An iterator that yields chunks of the input iterable.
47
+
48
+ Example usage:
49
+ ```
50
+ data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
51
+ for chunk in chunks(data, 3):
52
+ print(chunk)
53
+ ```
54
+ Output:
55
+ ```
56
+ [1, 2, 3]
57
+ [4, 5, 6]
58
+ [7, 8, 9]
59
+ [10]
60
+ ```
61
+ """
62
+ arr = []
63
+ for i, x in enumerate(iter):
64
+ arr.append(x)
65
+ if len(arr) == (fn(i, iter) if fn else n):
66
+ yield arr
67
+ arr = []
68
+
69
+ if arr:
70
+ yield arr
71
+
72
+
73
+ class MultiChoice:
74
+ def __init__(self, choices) -> None:
75
+ self.choices = choices
76
+
77
+ # Simple wildcard support (linux filename patterns)
78
+ def __contains__(self, values) -> bool:
79
+ for value in values.split(","):
80
+ if len(fnmatch.filter(self.choices, value)) == 0:
81
+ eval_logger.info("Available tasks to choose:")
82
+ for choice in self.choices:
83
+ eval_logger.info(f" - {choice}")
84
+ raise ValueError("'{}' is not in task list".format(value))
85
+ return True
86
+
87
+ def __iter__(self) -> Iterator:
88
+ for choice in self.choices:
89
+ yield choice
90
+
91
+
92
+ class Grouper:
93
+ """
94
+ takes an array `arr` and function `fn` and returns a dictionary
95
+ with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all
96
+ objects in `arr` satisfying `key == fn(ob)`.
97
+ """
98
+
99
+ def __init__(self, arr, fn) -> None:
100
+ # self.orig_arr = arr
101
+ self.size = len(arr)
102
+ arr = list(enumerate(arr))
103
+
104
+ def group_return_dict(arr, fn):
105
+ res = collections.defaultdict(list)
106
+
107
+ for ob in arr:
108
+ res[fn(ob)].append(ob)
109
+ return res
110
+
111
+ arr = group_return_dict(arr, lambda x: fn(x[1]))
112
+
113
+ # self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
114
+ self.arr = arr
115
+ self._grouped = None
116
+
117
+ def get_grouped(self):
118
+ # return the contents but not indices for our grouped dict.
119
+ if self._grouped:
120
+ return self._grouped
121
+ grouped = {}
122
+ for key in self.arr.keys():
123
+ # drop the index from each element of self.arr
124
+ grouped[key] = [y[1] for y in self.arr[key]]
125
+ self._grouped = grouped
126
+ return grouped
127
+
128
+ def get_original(self, grouped_dict):
129
+ # take in a grouped dictionary with e.g. results for each key listed
130
+ # in the same order as the instances in `self.arr`, and
131
+ # return the results in the same (single list) order as `self.orig_arr`.
132
+ res = [None] * self.size
133
+ cov = [False] * self.size
134
+ # orig = [None] * self.size
135
+
136
+ assert grouped_dict.keys() == self.arr.keys()
137
+
138
+ for key in grouped_dict.keys():
139
+ for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
140
+ res[ind] = v
141
+ cov[ind] = True
142
+ # orig[ind] = _
143
+
144
+ assert all(cov)
145
+ # assert orig == self.orig_arr
146
+
147
+ return res
148
+
149
+
150
+ def pad_and_concat(
151
+ max_length: int,
152
+ tensors: List[torch.Tensor],
153
+ padding_side: Literal["right", "left"] = "right",
154
+ ):
155
+ """
156
+ Method for padding a list of tensors given the maximum tensor
157
+ length in the batch. Used for batching inputs and continuations in
158
+ seq2seq models.
159
+ """
160
+ assert padding_side == "left" or padding_side == "right", (
161
+ f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
162
+ )
163
+
164
+ for i, tensor in enumerate(tensors):
165
+ if len(tensor.shape) == 2:
166
+ tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
167
+ tensor_len = tensor.shape[0]
168
+ if tensor_len < max_length:
169
+ if padding_side == "right":
170
+ # right-pad
171
+ tensors[i] = torch.cat(
172
+ [
173
+ tensor, # [seq]
174
+ torch.zeros(
175
+ max_length - tensor_len,
176
+ dtype=torch.long,
177
+ device=tensor.device,
178
+ ), # [padding_length - seq]
179
+ ],
180
+ dim=0,
181
+ ).unsqueeze(0)
182
+ else:
183
+ # left-pad
184
+ tensors[i] = torch.cat(
185
+ [
186
+ torch.zeros(
187
+ max_length - tensor_len,
188
+ dtype=torch.long,
189
+ device=tensor.device,
190
+ ), # [padding_length - seq]
191
+ tensor, # [seq]
192
+ ],
193
+ dim=0,
194
+ ).unsqueeze(0)
195
+ else:
196
+ tensors[i] = tensor.unsqueeze(0)
197
+
198
+ return torch.cat(tensors, dim=0)
199
+
200
+
201
+ def clear_torch_cache() -> None:
202
+ gc.collect()
203
+ torch.cuda.empty_cache()
204
+
205
+
206
+ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
207
+ """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
208
+ if isinstance(dtype, str) and dtype != "auto":
209
+ # Convert `str` args torch dtype: `float16` -> `torch.float16`
210
+ _torch_dtype = getattr(torch, dtype)
211
+ else:
212
+ _torch_dtype = dtype
213
+ return _torch_dtype
214
+
215
+
216
+ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
217
+ """Criteria to stop on the specified multi-token sequence."""
218
+
219
+ def __init__(
220
+ self,
221
+ sequence: str,
222
+ tokenizer: transformers.PreTrainedTokenizer,
223
+ initial_decoder_input_length: int,
224
+ batch_size: int,
225
+ ) -> None:
226
+ self.initial_decoder_input_length = initial_decoder_input_length
227
+ self.done_tracker = [False] * batch_size
228
+ self.sequence = sequence
229
+ self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
230
+ # print(sequence, self.sequence_ids)
231
+ # we look back for 2 more tokens than it takes to encode our stop sequence
232
+ # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
233
+ # and we don't want to mistakenly not stop a generation because our
234
+ # (string) stop sequence was output in a different tokenization
235
+
236
+ # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
237
+ # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
238
+ # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
239
+ self.sequence_id_len = len(self.sequence_ids) + 2
240
+ self.tokenizer = tokenizer
241
+
242
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
243
+ # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
244
+ lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
245
+
246
+ lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
247
+
248
+ lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
249
+
250
+ for i, done in enumerate(self.done_tracker):
251
+ if not done:
252
+ self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
253
+ return False not in self.done_tracker
254
+
255
+
256
+ def stop_sequences_criteria(
257
+ tokenizer: transformers.PreTrainedTokenizer,
258
+ stop_sequences: List[str],
259
+ initial_decoder_input_length: int,
260
+ batch_size: int,
261
+ ) -> transformers.StoppingCriteriaList:
262
+ return transformers.StoppingCriteriaList(
263
+ [
264
+ *[
265
+ MultiTokenEOSCriteria(
266
+ sequence, tokenizer, initial_decoder_input_length, batch_size
267
+ )
268
+ for sequence in stop_sequences
269
+ ],
270
+ ]
271
+ )
272
+
273
+
274
+ def undistribute(iterable):
275
+ """
276
+ Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute .
277
+
278
+ Re-interleaves results that have been split using more_itertools.distribute:
279
+ >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6])
280
+ >>> list(group_1)
281
+ [1, 3, 5]
282
+ >>> list(group_2)
283
+ [2, 4, 6]
284
+ >>> undistribute([group_1, group_2])
285
+ [1, 2, 3, 4, 5, 6]
286
+
287
+ Handles non-uniform component lengths:
288
+
289
+ >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7])
290
+ >>> [list(c) for c in children]
291
+ [[1, 4, 7], [2, 5], [3, 6]]
292
+ >>> undistribute(children)
293
+ [1, 2, 3, 4, 5, 6, 7]
294
+
295
+ Also handles when some iterables are empty:
296
+
297
+ >>> children = distribute(5, [1, 2, 3])
298
+ >>> [list(c) for c in children]
299
+ [[1], [2], [3], [], []]
300
+ >>> undistribute(children)
301
+ [1, 2, 3]
302
+
303
+ """
304
+
305
+ return [
306
+ x
307
+ for x in itertools.chain.from_iterable(
308
+ itertools.zip_longest(*[list(x) for x in iterable])
309
+ )
310
+ if x is not None
311
+ ]
312
+
313
+
314
+ def retry_on_specific_exceptions(
315
+ on_exceptions: List[Type[Exception]],
316
+ max_retries: Optional[int] = None,
317
+ backoff_time: float = 3.0,
318
+ backoff_multiplier: float = 1.5,
319
+ on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
320
+ ):
321
+ """Retry on an LLM Provider's rate limit error with exponential backoff
322
+ For example, to use for OpenAI, do the following:
323
+ ```
324
+ from openai import RateLimitError
325
+
326
+ # Recommend specifying max_retries to avoid infinite loops!
327
+ @retry_on_specific_exceptions([RateLimitError], max_retries=3)
328
+ def completion(...):
329
+ # Wrap OpenAI completion function here
330
+ ...
331
+ ```
332
+ """
333
+
334
+ def decorator(func: Callable):
335
+ @wraps(func)
336
+ def wrapper(*args, **kwargs):
337
+ sleep_time = backoff_time
338
+ attempt = 0
339
+ while max_retries is None or attempt < max_retries:
340
+ try:
341
+ return func(*args, **kwargs)
342
+ except tuple(on_exceptions) as e:
343
+ if on_exception_callback is not None:
344
+ on_exception_callback(e, sleep_time)
345
+ time.sleep(sleep_time)
346
+ sleep_time *= backoff_multiplier
347
+ attempt += 1
348
+
349
+ return wrapper
350
+
351
+ return decorator
352
+
353
+
354
+ class Collator:
355
+ """
356
+ A class for reordering and batching elements of an array.
357
+
358
+ This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
359
+
360
+ Objects of this class have the group_by attribute which determines the method for grouping
361
+ the data while batching it. Three options include "gen_kwargs", "contexts", or None:
362
+ If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs
363
+ If group_by == "contexts" then requests will be grouped by context + cont[:-1]
364
+ If None then requests will just be reordered by length descending.
365
+ """
366
+
367
+ def __init__(
368
+ self,
369
+ arr: List,
370
+ sort_fn: Callable = lambda x: x,
371
+ group_fn: Callable = lambda x: x[1],
372
+ group_by: Union[Literal["gen_kwargs", "contexts"], None] = None,
373
+ ) -> None:
374
+ self._group_by = group_by
375
+ # 0 indices are enumerated indices. Apply functions to original arr.
376
+ self._sort_fn = lambda x: sort_fn(x[1])
377
+ self._group_fn = lambda x: group_fn(x[1])
378
+ self._reorder_indices: List = []
379
+ self._size = len(arr)
380
+ self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple(
381
+ enumerate(arr)
382
+ ) # [indices, (arr)]
383
+ if self._group_by == "contexts":
384
+ self._group_by_context()
385
+ elif self._group_by == "gen_kwargs":
386
+ self._group_by_index()
387
+
388
+ def _group_by_index(self) -> None:
389
+ """Group the elements of a list based on their indices."""
390
+ self._arr_with_indices = self.group(
391
+ self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs"
392
+ )
393
+
394
+ def _group_by_context(self) -> None:
395
+ """Group the array with indices by context."""
396
+ self._arr_with_indices = self.group(
397
+ self._arr_with_indices, fn=self._group_fn, group_by="contexts"
398
+ )
399
+
400
+ def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
401
+ """
402
+ Generates and yields batches from the reordered array. The method of grouping and batching
403
+ depends on the parameter `group_by`.
404
+ If `group_by` is set to "gen_kwargs", it will batch the
405
+ re-ordered values with same gen_kwargs for each batch.
406
+ If `group_by` is "contexts", it caches the requests by context before batching.
407
+ If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array
408
+
409
+ Parameters:
410
+ - n (int): The size of each batch. Defaults to 1.
411
+ - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
412
+ each batch. Optional, defaults to None.
413
+
414
+ Returns:
415
+ Iterator: An iterator over batches of reordered elements grouped as per the `group_by`
416
+ attribute.
417
+
418
+ Yields:
419
+ List of batched elements according to the `group_by` attribute.
420
+ """
421
+ if self._group_by == "gen_kwargs":
422
+ for (
423
+ key,
424
+ values,
425
+ ) in self._arr_with_indices.items(): # type: ignore
426
+ values = self._reorder(values)
427
+ batch = self.get_chunks(values, n=n, fn=batch_fn)
428
+ yield from batch
429
+ elif self._group_by == "contexts":
430
+ # Get one sample from each key
431
+ values = self._reorder(
432
+ [value[0] for value in self._arr_with_indices.values()]
433
+ )
434
+ batch = self.get_chunks(values, n=n, fn=batch_fn)
435
+ yield from batch
436
+ else:
437
+ values = self._reorder(self._arr_with_indices) # type: ignore
438
+ batch = self.get_chunks(values, n=n, fn=batch_fn)
439
+ yield from batch
440
+
441
+ def get_cache(
442
+ self,
443
+ req_str: Tuple[str, str] = None,
444
+ cxt_toks: List[int] = None,
445
+ cont_toks: List[int] = None,
446
+ logits: torch.Tensor = None,
447
+ ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]:
448
+ """
449
+ Retrieves cached single-token continuations and their associated arguments, updating indices as necessary.
450
+
451
+ The behavior of this function varies depending on how the `group_by` attribute is set:
452
+
453
+ - When `group_by` is "contexts":
454
+ The function identifies single-token continuations by checking for keys that equate to
455
+ [context+continuation][-1] and logs the indices for re-ordering.
456
+ In this mode, this function can work in two scenarios:
457
+
458
+ 1. Cache Hit - Single Match:
459
+ If a single matching context-continuation pair is found in the cache,
460
+ the function yields the original arguments.
461
+
462
+ 2. Cache Hit - Multiple Matches:
463
+ If multiple matching context-continuation pairs are found in the cache,
464
+ the function expands the logits batch dimension to match the number of cache hits.
465
+ It updates the original requests and continuation tokens.
466
+
467
+ - When `group_by` is not set to "contexts":
468
+ This method yields the original arguments, logits and continuation tokens,
469
+ without checking for one-token continuations.
470
+
471
+ Parameters:
472
+ - req_str (tuple[str, str]): Original strings used for CachingLM.
473
+ - cxt_toks (list[int]): Full context tokens used for lookup.
474
+ - cont_toks (list[int]): Continuation tokens for which logits were generated.
475
+ - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys.
476
+
477
+ Yields:
478
+ - Iterator:
479
+ - req_str (tuple[str, str]): strings used for CachingLM.
480
+ - cont_toks (list[int]) : continuation tokens.
481
+ - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times)
482
+ """
483
+ if self._group_by == "contexts":
484
+ cache_hit: List[
485
+ Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]]
486
+ ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1]))
487
+ if (cache_size := len(cache_hit)) == 1:
488
+ self._reorder_indices.extend(x[0] for x in cache_hit)
489
+ yield req_str, cont_toks, logits
490
+ else:
491
+ # If we have matching requests then expand the batch dimension (no-op) and
492
+ # yield each along with its corresponding args.
493
+ multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size)
494
+ indices, req_str, cont_toks = zip(
495
+ *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit]
496
+ )
497
+ self._reorder_indices.extend(indices)
498
+ for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits):
499
+ yield c_key, cont_tok, logit
500
+ else:
501
+ yield req_str, cont_toks, logits
502
+
503
+ def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
504
+ """
505
+ Reorders the elements in the array based on the sorting function.
506
+
507
+ Parameters:
508
+ - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered.
509
+
510
+ Yields:
511
+ Iterator
512
+ """
513
+ arr = sorted(arr, key=self._sort_fn)
514
+ if not self._group_by == "contexts":
515
+ # If grouped by contexts then indices will be set in get_cache()
516
+ self._reorder_indices.extend([x[0] for x in arr])
517
+ yield from [x[1] for x in arr]
518
+
519
+ def get_original(self, newarr: List) -> List:
520
+ """
521
+ Restores the original order of elements from the reordered list.
522
+
523
+ Parameters:
524
+ - newarr (list): The reordered array.
525
+
526
+ Returns:
527
+ list: The array with elements restored to their original order.
528
+ """
529
+ res = [None] * self._size
530
+ cov = [False] * self._size
531
+
532
+ for ind, v in zip(self._reorder_indices, newarr):
533
+ res[ind] = v
534
+ cov[ind] = True
535
+
536
+ assert all(cov)
537
+
538
+ return res
539
+
540
+ def __len__(self):
541
+ return self._size
542
+
543
+ @staticmethod
544
+ def group(
545
+ arr: Iterable,
546
+ fn: Callable,
547
+ group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs",
548
+ ) -> dict:
549
+ """
550
+ Groups elements of an iterable based on a provided function.
551
+
552
+
553
+ The `group_by` parameter determines the method of grouping.
554
+ If `group_by` is "contexts", the elements are grouped by [context + cont][:-1].
555
+ If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict.
556
+
557
+ Parameters:
558
+ - arr (Iterable): The iterable to be grouped.
559
+ - fn (Callable): The function to determine the grouping.
560
+ - values (bool): If True, returns the values of the group. Defaults to False.
561
+
562
+ Returns:
563
+ Iterator: An iterable of grouped elements.
564
+ """
565
+ res = collections.defaultdict(list)
566
+ for ob in arr:
567
+ # where ob == [context + cont]
568
+ if group_by == "contexts":
569
+ res[tuple(fn(ob))].append(ob)
570
+ else:
571
+ try:
572
+ hashable_dict = tuple(
573
+ (
574
+ key,
575
+ tuple(value)
576
+ if isinstance(value, collections.abc.Iterable)
577
+ else value,
578
+ )
579
+ for key, value in sorted(fn(ob).items())
580
+ )
581
+ res[hashable_dict].append(ob)
582
+ except (TypeError, AttributeError):
583
+ res[tuple(fn(ob))].append(ob)
584
+ return res
585
+
586
+ @staticmethod
587
+ def get_chunks(_iter, n: int = 0, fn=None):
588
+ """
589
+ Divides an iterable into chunks of specified size or based on a given function.
590
+ Useful for batching
591
+
592
+ Parameters:
593
+ - iter: The input iterable to be divided into chunks.
594
+ - n: An integer representing the size of each chunk. Default is 0.
595
+ - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
596
+
597
+ Returns:
598
+ An iterator that yields chunks of the input iterable.
599
+
600
+ Example usage:
601
+ ```
602
+ data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
603
+ for chunk in chunks(data, 3):
604
+ print(chunk)
605
+ ```
606
+ Output:
607
+ ```
608
+ [1, 2, 3]
609
+ [4, 5, 6]
610
+ [7, 8, 9]
611
+ [10]
612
+ ```
613
+ """
614
+ arr = []
615
+ _iter = tuple(_iter)
616
+ for i, x in enumerate(_iter):
617
+ arr.append(x)
618
+ if len(arr) == (fn(i, _iter) if fn else n):
619
+ yield arr
620
+ arr = []
621
+
622
+ if arr:
623
+ yield arr
624
+
625
+
626
+ def configure_pad_token(
627
+ tokenizer: "PreTrainedTokenizerBase",
628
+ model_config: Optional["PretrainedConfig"] = None,
629
+ ) -> "PreTrainedTokenizerBase":
630
+ """
631
+ This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present.
632
+ Some tokenizers require special handling.
633
+
634
+ Args:
635
+ tokenizer: The tokenizer for which the padding token is to be handled.
636
+ model_config: The configuration of the model. Default is None.
637
+
638
+ Returns:
639
+ The tokenizer after the padding token has been handled.
640
+
641
+ Raises:
642
+ AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0.
643
+ """
644
+ if tokenizer.pad_token:
645
+ pass
646
+ elif tokenizer.unk_token:
647
+ tokenizer.pad_token_id = tokenizer.unk_token_id
648
+ elif tokenizer.eos_token:
649
+ tokenizer.pad_token_id = tokenizer.eos_token_id
650
+ else:
651
+ # handle special cases
652
+ if model_config and getattr(model_config, "model_type", None) == "qwen":
653
+ # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
654
+ tokenizer.pad_token = "<|endoftext|>"
655
+ elif (
656
+ tokenizer.__class__.__name__ == "RWKVWorldTokenizer"
657
+ or tokenizer.__class__.__name__ == "Rwkv5Tokenizer"
658
+ ):
659
+ # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0)
660
+ # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer
661
+ # ---
662
+ # Note that the world tokenizer class name, might change in the future for the final huggingface merge
663
+ # https://github.com/huggingface/transformers/pull/26963
664
+ assert tokenizer.pad_token_id == 0
665
+ else:
666
+ tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
667
+
668
+ return tokenizer
669
+
670
+
671
+ def replace_placeholders(
672
+ string: str, default_placeholder: str, image_token: str, max_images: int
673
+ ):
674
+ """
675
+ A utility function used for local multimodal models. It locates all `placeholder` string
676
+ occurrences in the given input `string_` and replaces the first `max_count` instances with
677
+ `replacement`, and all subsequent occurrences with the empty string.
678
+
679
+ This is used to replace <image> placeholder tags by model-specific image tokens like <|image_pad|>
680
+ and to allow for only the first `max_count` images to be passed to a model if desired.
681
+
682
+ :param string: The original string containing placeholders.
683
+ :param default_placeholder: The placeholder text to be replaced.
684
+ :param image_token: The token to replace the placeholder with.
685
+ :param max_images: The maximum number of replacements to make.
686
+ :return: The string with placeholders replaced.
687
+ """
688
+ count = 0
689
+ result = []
690
+
691
+ parts = string.split(default_placeholder)
692
+ for part in parts[:-1]: # Iterate through all but the last part
693
+ result.append(part)
694
+ if count < max_images:
695
+ result.append(image_token)
696
+ count += 1
697
+ elif default_placeholder != image_token:
698
+ result.append(default_placeholder)
699
+
700
+ # Add the last part of the string
701
+ result.append(parts[-1])
702
+ return "".join(result)
703
+
704
+
705
+ def flatten_image_list(images: List[List]):
706
+ """
707
+ Takes in a list of lists of images, and returns a single list of all images in order.
708
+ Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor.
709
+
710
+ :param images: A list of lists of PIL images.
711
+ :return: a list of PIL images, via concatenating all the sub-lists in order.
712
+ """
713
+ return [image for image_list in images for image in image_list]
714
+
715
+
716
+ def handle_stop_sequences(
717
+ until: Union[str, List[str], None], eos: Optional[str]
718
+ ) -> List[str]:
719
+ """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
720
+ if isinstance(until, str):
721
+ until = [until]
722
+ elif until is None:
723
+ until = []
724
+ elif not isinstance(until, list):
725
+ raise ValueError(
726
+ f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
727
+ )
728
+
729
+ if eos is not None and eos not in until:
730
+ until.append(eos)
731
+ return until
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/models/verifier.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ import ast
4
+ import re
5
+ import numpy as np
6
+ import textwrap
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class CodeVerifier:
11
+ def __init__(self, model, tokenizer, device="cuda"):
12
+ self.model = model
13
+ self.tokenizer = tokenizer
14
+ self.device = device
15
+
16
+ self.yes_ids, self.no_ids = [], []
17
+ for t in ["Yes", " Yes", "YES"]:
18
+ ids = self.tokenizer.encode(t, add_special_tokens=False)
19
+ if len(ids) > 0: self.yes_ids.append(ids[-1])
20
+ for t in ["No", " No", "NO"]:
21
+ ids = self.tokenizer.encode(t, add_special_tokens=False)
22
+ if len(ids) > 0: self.no_ids.append(ids[-1])
23
+
24
+ self.yes_ids = list(set(self.yes_ids))
25
+ self.no_ids = list(set(self.no_ids))
26
+
27
+ def _extract_python_code(self, text):
28
+ text = text.strip()
29
+ match = re.search(r"```python\s*(.*?)```", text, re.DOTALL)
30
+ if match: return match.group(1)
31
+ match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL)
32
+ if match_generic: return match_generic.group(1)
33
+ return text
34
+
35
+ def check_syntax(self, code_str):
36
+ clean_code = self._extract_python_code(code_str)
37
+ try:
38
+ if len(clean_code.strip()) < 5: return False
39
+ ast.parse(clean_code)
40
+ return True
41
+ except:
42
+ return False
43
+
44
+ def compute_confidence(self, logits):
45
+ if logits is None: return 0.0
46
+ probs = torch.softmax(logits, dim=-1)
47
+ max_probs, _ = torch.max(probs, dim=-1)
48
+ log_probs = torch.log(max_probs + 1e-10)
49
+ return torch.exp(torch.mean(log_probs)).item()
50
+
51
+ def svf_score(self, prompt, code_str, task_type="code"):
52
+
53
+ max_len = 2000
54
+ if len(code_str) > max_len:
55
+ if task_type == "reasoning":
56
+ truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):]
57
+ else:
58
+ truncated_code = code_str[-max_len:]
59
+ else:
60
+ truncated_code = code_str
61
+
62
+ if task_type == "code":
63
+ prompt_template = f"""
64
+ You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints.
65
+
66
+ [Problem Statement]
67
+ {prompt}
68
+ [/Problem Statement]
69
+
70
+ [Proposed Python Solution]
71
+ ```python
72
+ {truncated_code}
73
+ ```
74
+ [/Proposed Python Solution]
75
+
76
+ **Analysis Steps:**
77
+ 1. Correctness: Does the core algorithm correctly solve the problem?
78
+ 2. Efficiency: Is the time complexity acceptable for the given constraints?
79
+ 3. Edge Cases & Constraints: Does the code handle all rules and edge cases?
80
+
81
+ **Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No.
82
+ **Answer:** """
83
+
84
+ elif task_type == "math":
85
+ prompt_template = f"""
86
+ You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy.
87
+
88
+ [Math Problem]
89
+ {prompt}
90
+ [/Math Problem]
91
+
92
+ [Proposed Mathematical Solution]
93
+ {truncated_code}
94
+ [/Proposed Mathematical Solution]
95
+
96
+ **Analysis Steps:**
97
+ 1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly?
98
+ 2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate?
99
+ 3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem?
100
+
101
+ **Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No.
102
+ **Answer:** """
103
+
104
+ elif task_type == "reasoning":
105
+ prompt_template = f"""
106
+ You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question.
107
+
108
+ [Context and Question]
109
+ {prompt}
110
+ [/Context and Question]
111
+
112
+ [Proposed Answer]
113
+ {truncated_code}
114
+ [/Proposed Answer]
115
+
116
+ **Analysis Steps :**
117
+ 1. Faithfulness: Is the answer an exact, literal span from the context?
118
+ 2. Relevance: Does the answer directly address the specific question asked without hallucinating external information?
119
+ 3. Accuracy: Does the provided context strictly support this answer?
120
+
121
+ **Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No.
122
+ **Answer:** """
123
+
124
+ else:
125
+ prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:"
126
+
127
+ verify_text = textwrap.dedent(prompt_template).strip()
128
+ input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device)
129
+
130
+ max_pos = getattr(self.model.config, "max_position_embeddings",
131
+ getattr(self.model.config, "n_positions",
132
+ getattr(self.model.config, "max_sequence_length", 20480)))
133
+
134
+ if input_ids.shape[1] > max_pos - 16:
135
+ logger.warning("Verifier input is too long, truncating from the left.")
136
+ input_ids = input_ids[:, -(max_pos - 16):]
137
+
138
+ with torch.no_grad():
139
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
140
+ outputs = self.model(input_ids, 'full')
141
+ logits = outputs.logits[0, -1, :]
142
+
143
+ yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf'))
144
+ no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf'))
145
+
146
+ if yes_score == -float('inf') and no_score == -float('inf'): return 0.5
147
+
148
+ probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0)
149
+ return probs[0].item()
150
+
151
+ def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"):
152
+ if mode == "svf":
153
+ return self.svf_score(prompt, code_str, task_type=task_type)
154
+ else:
155
+ return self.compute_confidence(current_logits)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/prompts/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+ from typing import Dict
5
+
6
+ from lm_eval import utils
7
+
8
+
9
+ eval_logger = logging.getLogger(__name__)
10
+
11
+ # Prompt library.
12
+ # Stores prompts in a dictionary indexed by 2 levels:
13
+ # prompt category name, and prompt name.
14
+ # This allows us to access prompts
15
+ PROMPT_REGISTRY: Dict[str, Dict[str, str]] = {
16
+ "qa-basic": {
17
+ "question-newline-answer": "Question: {{question}}\nAnswer:",
18
+ "q-newline-a": "Q: {{question}}\nA:",
19
+ },
20
+ }
21
+
22
+
23
+ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
24
+ # unpack prompt name
25
+ category_name, prompt_name = prompt_id.split(":")
26
+ if subset_name is None:
27
+ dataset_full_name = dataset_name
28
+ else:
29
+ dataset_full_name = f"{dataset_name}-{subset_name}"
30
+ eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
31
+ if category_name == "promptsource":
32
+ try:
33
+ from promptsource.templates import DatasetTemplates
34
+ except ModuleNotFoundError as exception:
35
+ raise type(exception)(
36
+ "Tried to load a Promptsource template, but promptsource is not installed ",
37
+ "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]",
38
+ )
39
+ try:
40
+ if subset_name is None:
41
+ prompts = DatasetTemplates(dataset_name=dataset_name)
42
+ else:
43
+ prompts = DatasetTemplates(
44
+ dataset_name=dataset_name, subset_name=subset_name
45
+ )
46
+ except Exception:
47
+ raise ValueError(f"{dataset_name} and {subset_name} not found")
48
+ if prompt_name in prompts.all_template_names:
49
+ return prompts[prompt_name]
50
+ else:
51
+ raise ValueError(
52
+ f"{prompt_name} not in prompt list {prompts.all_template_names}"
53
+ )
54
+ elif ".yaml" in category_name:
55
+ import yaml
56
+
57
+ with open(category_name, "rb") as file:
58
+ prompt_yaml_file = yaml.full_load(file)
59
+
60
+ prompt_string = prompt_yaml_file["prompts"][prompt_name]
61
+ return PromptString(prompt_string)
62
+ else:
63
+ try:
64
+ return PROMPT_REGISTRY[category_name][prompt_name]
65
+ except Exception:
66
+ raise ValueError(
67
+ f"expected only a single `:` as separator between \
68
+ prompt category and name, but got `{prompt_id}` instead"
69
+ )
70
+
71
+
72
+ def load_prompt_list(
73
+ use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
74
+ ):
75
+ category_name, prompt_name = use_prompt.split(":")
76
+
77
+ if category_name == "promptsource":
78
+ from promptsource.templates import DatasetTemplates
79
+
80
+ if subset_name is None:
81
+ prompts = DatasetTemplates(dataset_name=dataset_name)
82
+ else:
83
+ prompts = DatasetTemplates(
84
+ dataset_name=dataset_name, subset_name=subset_name
85
+ )
86
+
87
+ prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
88
+
89
+ elif ".yaml" in category_name:
90
+ import yaml
91
+
92
+ if yaml_path is not None:
93
+ category_name = os.path.realpath(os.path.join(yaml_path, category_name))
94
+
95
+ with open(category_name, "rb") as file:
96
+ prompt_yaml_file = yaml.full_load(file)
97
+
98
+ prompt_list = utils.pattern_match(
99
+ prompt_name, prompt_yaml_file["prompts"].keys()
100
+ )
101
+
102
+ # category_name, *prompt_name = use_prompt.split(":")
103
+ # TODO allow to multiple prompt naming
104
+ # if len(prompt_name) > 1:
105
+ # prompt_list = []
106
+ # for prompt in prompt_name:
107
+ # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names))
108
+ # else:
109
+ # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names)
110
+ return [":".join([category_name, prompt]) for prompt in prompt_list]
111
+
112
+
113
+ class PromptString:
114
+ def __init__(self, prompt_string):
115
+ self.prompt_string = prompt_string
116
+
117
+ def apply(self, doc):
118
+ doc_to_text = self.prompt_string["doc_to_text"]
119
+ doc_to_target = self.prompt_string["doc_to_target"]
120
+
121
+ # TODO need a way to process doc_to_choice
122
+ if "doc_to_choice" in self.prompt_string:
123
+ raise NotImplementedError("Not yet implemented to accept doc_to_choice")
124
+
125
+ text_string = utils.apply_template(doc_to_text, doc)
126
+ target_string = utils.apply_template(doc_to_target, doc)
127
+
128
+ return [text_string, target_string]
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/README.md ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Tasks
3
+
4
+ A list of supported tasks and task groupings can be viewed with `lm-eval --tasks list`.
5
+
6
+ For more information, including a full list of task names and their precise meanings or sources, follow the links provided to the individual README.md files for each subfolder.
7
+
8
+ | Task Family | Description | Language(s) |
9
+ |--------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------|
10
+ | [aclue](aclue/README.md) | Tasks focusing on ancient Chinese language understanding and cultural aspects. | Ancient Chinese |
11
+ | [aexams](aexams/README.md) | Tasks in Arabic related to various academic exams covering a range of subjects. | Arabic |
12
+ | [agieval](agieval/README.md) | Tasks involving historical data or questions related to history and historical texts. | English, Chinese |
13
+ | [anli](anli/README.md) | Adversarial natural language inference tasks designed to test model robustness. | English |
14
+ | [arabic_leaderboard_complete](arabic_leaderboard_complete/README.md) | A full version of the tasks in the Open Arabic LLM Leaderboard, focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) |
15
+ | [arabic_leaderboard_light](arabic_leaderboard_light/README.md) | A light version of the tasks in the Open Arabic LLM Leaderboard (i.e., 10% samples of the test set in the original benchmarks), focusing on the evaluation of models that reflect the characteristics of Arabic language understanding and comprehension, culture, and heritage. Note that some of these tasks are machine-translated. | Arabic (Some MT) |
16
+ | [arabicmmlu](arabicmmlu/README.md) | Localized Arabic version of MMLU with multiple-choice questions from 40 subjects. | Arabic |
17
+ | [AraDICE](aradice/README.md) | A collection of multiple tasks carefully designed to evaluate dialectal and cultural capabilities in large language models (LLMs). | Arabic |
18
+ | [arc](arc/README.md) | Tasks involving complex reasoning over a diverse set of questions. | English |
19
+ | [arithmetic](arithmetic/README.md) | Tasks involving numerical computations and arithmetic reasoning. | English |
20
+ | [asdiv](asdiv/README.md) | Tasks involving arithmetic and mathematical reasoning challenges. | English |
21
+ | [babi](babi/README.md) | Tasks designed as question and answering challenges based on simulated stories. | English |
22
+ | [basque_bench](basque_bench/README.md) | Collection of tasks in Basque encompassing various evaluation areas. | Basque |
23
+ | [basqueglue](basqueglue/README.md) | Tasks designed to evaluate language understanding in Basque language. | Basque |
24
+ | [bbh](bbh/README.md) | Tasks focused on deep semantic understanding through hypothesization and reasoning. | English, German |
25
+ | [bbq](bbq/README.md) | A question-answering benchmark designed to measure social biases in language models across various demographic categories and contexts. | English |
26
+ | [belebele](belebele/README.md) | Language understanding tasks in a variety of languages and scripts. | Multiple (122 languages) |
27
+ | benchmarks | General benchmarking tasks that test a wide range of language understanding capabilities. | |
28
+ | [bertaqa](bertaqa/README.md) | Local Basque cultural trivia QA tests in English and Basque languages. | English, Basque, Basque (MT) |
29
+ | [bigbench](bigbench/README.md) | Broad tasks from the BIG-bench benchmark designed to push the boundaries of large models. | Multiple |
30
+ | [blimp](blimp/README.md) | Tasks testing grammatical phenomena to evaluate language model's linguistic capabilities. | English |
31
+ | [careqa](careqa/README.md) | Multiple choice and open-ended medical question answering based on the Spanish Specialised Healthcare Training (MIR) exams. | English, Spanish |
32
+ | [catalan_bench](catalan_bench/README.md) | Collection of tasks in Catalan encompassing various evaluation areas. | Catalan |
33
+ | [ceval](ceval/README.md) | Tasks that evaluate language understanding and reasoning in an educational context. | Chinese |
34
+ | [cmmlu](cmmlu/README.md) | Multi-subject multiple choice question tasks for comprehensive academic assessment. | Chinese |
35
+ | code_x_glue | Tasks that involve understanding and generating code across multiple programming languages. | Go, Java, JS, PHP, Python, Ruby |
36
+ | [commonsense_qa](commonsense_qa/README.md) | CommonsenseQA, a multiple-choice QA dataset for measuring commonsense knowledge. | English |
37
+ | [copal_id](copal_id/README.md) United States | Indonesian causal commonsense reasoning dataset that captures local nuances. | Indonesian |
38
+ | [coqa](coqa/README.md) | Conversational question answering tasks to test dialog understanding. | English |
39
+ | [crows_pairs](crows_pairs/README.md) | Tasks designed to test model biases in various sociodemographic groups. | English, French |
40
+ | csatqa | Tasks related to SAT and other standardized testing questions for academic assessment. | Korean |
41
+ | [drop](drop/README.md) | Tasks requiring numerical reasoning, reading comprehension, and question answering. | English |
42
+ | [eq_bench](eq_bench/README.md) | Tasks focused on equality and ethics in question answering and decision-making. | English |
43
+ | [eus_exams](eus_exams/README.md) | Tasks based on various professional and academic exams in the Basque language. | Basque |
44
+ | [eus_proficiency](eus_proficiency/README.md) | Tasks designed to test proficiency in the Basque language across various topics. | Basque |
45
+ | [eus_reading](eus_reading/README.md) | Reading comprehension tasks specifically designed for the Basque language. | Basque |
46
+ | [eus_trivia](eus_trivia/README.md) | Trivia and knowledge testing tasks in the Basque language. | Basque |
47
+ | [evalita-LLM](evalita-LLM/README.md) | A native Italian benchmark with diverse tasks formats and multiple prompts. | Italian |
48
+ | [fda](fda/README.md) | Tasks for extracting key-value pairs from FDA documents to test information extraction. | English |
49
+ | [fld](fld/README.md) | Tasks involving free-form and directed dialogue understanding. | English |
50
+ | [french_bench](french_bench/README.md) | Set of tasks designed to assess language model performance in French. | French |
51
+ | [galician_bench](galician_bench/README.md) | Collection of tasks in Galician encompassing various evaluation areas. | Galician |
52
+ | [global_mmlu](global_mmlu/README.md) | Collection of culturally sensitive and culturally agnostic MMLU tasks in 15 languages with human translations or post-edits. | Multiple (15 languages) |
53
+ | [glue](glue/README.md) | General Language Understanding Evaluation benchmark to test broad language abilities. | English |
54
+ | [gpqa](gpqa/README.md) | Tasks designed for general public question answering and knowledge verification. | English |
55
+ | [gsm8k](gsm8k/README.md) | A benchmark of grade school math problems aimed at evaluating reasoning capabilities. | English |
56
+ | [groundcocoa](groundcocoa/README.md) | A benchmark evaluating the conditional and compositional reasoning of language models using a grounding task. | English |
57
+ | [haerae](haerae/README.md) | Tasks focused on assessing detailed factual and historical knowledge. | Korean |
58
+ | [headqa](headqa/README.md) | A high-level education-based question answering dataset to test specialized knowledge. | Spanish, English |
59
+ | [hellaswag](hellaswag/README.md) | Tasks to predict the ending of stories or scenarios, testing comprehension and creativity. | English |
60
+ | [hendrycks_ethics](hendrycks_ethics/README.md) | Tasks designed to evaluate the ethical reasoning capabilities of models. | English |
61
+ | [hendrycks_math](hendrycks_math/README.md) | Mathematical problem-solving tasks to test numerical reasoning and problem-solving. | English |
62
+ | [histoires_morales](histoires_morales/README.md) | A dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | French (Some MT) |
63
+ | [hrm8k](hrm8k/README.md) | A challenging bilingual math reasoning benchmark for Korean and English. | Korean (Some MT), English (Some MT) |
64
+ | [humaneval](humaneval/README.md) | Code generation task that measure functional correctness for synthesizing programs from docstrings. | Python |
65
+ | [ifeval](ifeval/README.md) | Interactive fiction evaluation tasks for narrative understanding and reasoning. | English |
66
+ | [inverse_scaling](inverse_scaling/README.md) | Multiple-choice tasks from the Inverse Scaling Prize, designed to find settings where larger language models perform worse. | English |
67
+ | [japanese_leaderboard](japanese_leaderboard/README.md) | Japanese language understanding tasks to benchmark model performance on various linguistic aspects. | Japanese |
68
+ | [kbl](kbl/README.md) | Korean Benchmark for Legal Language Understanding. | Korean |
69
+ | [kmmlu](kmmlu/README.md) | Knowledge-based multi-subject multiple choice questions for academic evaluation. | Korean |
70
+ | [kobest](kobest/README.md) | A collection of tasks designed to evaluate understanding in Korean language. | Korean |
71
+ | [kormedmcqa](kormedmcqa/README.md) | Medical question answering tasks in Korean to test specialized domain knowledge. | Korean |
72
+ | [lambada](lambada/README.md) | Tasks designed to predict the endings of text passages, testing language prediction skills. | English |
73
+ | [lambada_cloze](lambada_cloze/README.md) | Cloze-style LAMBADA dataset. | English |
74
+ | [lambada_multilingual](lambada_multilingual/README.md) | Multilingual LAMBADA dataset. This is a legacy version of the multilingual dataset, and users should instead use `lambada_multilingual_stablelm`. | German, English, Spanish, French, Italian |
75
+ | [lambada_multilingual_stablelm](lambada_multilingual_stablelm/README.md) | Multilingual LAMBADA dataset. Users should prefer evaluating on this version of the multilingual dataset instead of on `lambada_multilingual`. | German, English, Spanish, French, Italian, Dutch, Portuguese |
76
+ | [leaderboard](leaderboard/README.md) | Task group used by Hugging Face's [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard). Those tasks are static and will not change through time | English |
77
+ | [lingoly](lingoly/README.md) | Challenging logical reasoning benchmark in low-resource languages with controls for memorization | English, Multilingual |
78
+ | [logiqa](logiqa/README.md) | Logical reasoning tasks requiring advanced inference and deduction. | English, Chinese |
79
+ | [logiqa2](logiqa2/README.md) | Large-scale logical reasoning dataset adapted from the Chinese Civil Service Examination. | English, Chinese |
80
+ | [mastermind](mastermind/README.md) | Reasoning benchmark based on the board game of Mastermind. | English |
81
+ | [mathqa](mathqa/README.md) | Question answering tasks involving mathematical reasoning and problem-solving. | English |
82
+ | [mbpp](mbpp/README.md) | A benchmark designed to measure the ability to synthesize short Python programs from natural language descriptions. | Python |
83
+ | [meddialog](meddialog/README.md) | Medical open-ended QA and Question Entailment stemming from the MedDialog dataset. | English |
84
+ | [medtext](medtext/README.md) | Medical open-ended QA from the MedText Clinical Notes dataset. | English |
85
+ | [mimic_repsum](mimic_repsum/README.md) | Medical report summarization from the MIMIC-III dataset. | English |
86
+ | [mc_taco](mc_taco/README.md) | Question-answer pairs that require temporal commonsense comprehension. | English |
87
+ | [med_concepts_qa](med_concepts_qa/README.md) | Benchmark for evaluating LLMs on their abilities to interpret medical codes and distinguish between medical concept. | English |
88
+ | [metabench](metabench/README.md) | Distilled versions of six popular benchmarks which are highly predictive of overall benchmark performance and of a single general ability latent trait. | English |
89
+ | [mediqa_qa2019](mediqa_qa2019/README.md) | Open-ended healthcare question answering benchmark from the MEDIQA 2019 challenge. | English |
90
+ | medmcqa | Medical multiple choice questions assessing detailed medical knowledge. | English |
91
+ | medqa | Multiple choice question answering based on the United States Medical License Exams. | |
92
+ | [meqsum](meqsum/README.md) | Healtcare Question Entailment benchmark from the MeqSum dataset. | |
93
+ | [mgsm](mgsm/README.md) | Benchmark of multilingual grade-school math problems. | Spanish, French, German, Russian, Chinese, Japanese, Thai, Swahili, Bengali, Telugu |
94
+ | [minerva_math](minerva_math/README.md) | Mathematics-focused tasks requiring numerical reasoning and problem-solving skills. | English |
95
+ | [mlqa](mlqa/README.md) | MultiLingual Question Answering benchmark dataset for evaluating cross-lingual question answering performance. | English, Arabic, German, Spanish, Hindi, Vietnamese, Simplified Chinese |
96
+ | [mmlu](mmlu/README.md) | Massive Multitask Language Understanding benchmark for broad domain language evaluation. Several variants are supported. | English |
97
+ | [mmlu_pro](mmlu_pro/README.md) | A refined set of MMLU, integrating more challenging, reasoning-focused questions and expanding the choice set from four to ten options. | English |
98
+ | [mmlu-pro-plus](mmlu-pro-plus/README.md) | A new test set for evaluating shortcut learning and higher-order reasoning of LLMs. | English |
99
+ | [mmlu_prox](mmlu_prox/README.md) | A multilingual benchmark that extends MMLU-Pro to multiple typologically diverse languages with human validation. | English, Japanese, Chinese, Korean, French, German, Spanish, Portuguese, Swahili, Thai, Arabic, Hindi, Bengali |
100
+ | [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
101
+ | model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
102
+ | [moral_stories](moral_stories/README.md) | A crowd-sourced dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | English
103
+ | [mts_dialog](mts_dialog/README.md) | Open-ended healthcare QA from the MTS-Dialog dataset. | English |
104
+ | [mutual](mutual/README.md) | A retrieval-based dataset for multi-turn dialogue reasoning. | English |
105
+ | [nq_open](nq_open/README.md) | Open domain question answering tasks based on the Natural Questions dataset. | English |
106
+ | [okapi/arc_multilingual](okapi/arc_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (31 languages) **Machine Translated.** |
107
+ | [okapi/hellaswag_multilingual](okapi/hellaswag_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (30 languages) **Machine Translated.** |
108
+ | okapi/mmlu_multilingual | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (34 languages) **Machine Translated.** |
109
+ | [okapi/truthfulqa_multilingual](okapi/truthfulqa_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (31 languages) **Machine Translated.** |
110
+ | [olaph](olaph/README.md) | Open-ended medical factuality Question Answering from the OLAPH dataset. | English |
111
+ | [openbookqa](openbookqa/README.md) | Open-book question answering tasks that require external knowledge and reasoning. | English |
112
+ | [paloma](paloma/README.md) | Paloma is a comprehensive benchmark designed to evaluate open language models across a wide range of domains, ranging from niche artist communities to mental health forums on Reddit. | English |
113
+ | [paws-x](paws-x/README.md) | Paraphrase Adversaries from Word Scrambling, focusing on cross-lingual capabilities. | English, French, Spanish, German, Chinese, Japanese, Korean |
114
+ | [pile](pile/README.md) | Open source language modelling data set that consists of 22 smaller, high-quality datasets. | English |
115
+ | [pile_10k](pile_10k/README.md) | The first 10K elements of The Pile, useful for debugging models trained on it. | English |
116
+ | [piqa](piqa/README.md) | Physical Interaction Question Answering tasks to test physical commonsense reasoning. | English |
117
+ | [polemo2](polemo2/README.md) | Sentiment analysis and emotion detection tasks based on Polish language data. | Polish |
118
+ | [portuguese_bench](portuguese_bench/README.md) | Collection of tasks in European Portuguese encompassing various evaluation areas. | Portuguese |
119
+ | [prost](prost/README.md) | Tasks requiring understanding of professional standards and ethics in various domains. | English |
120
+ | [pubmedqa](pubmedqa/README.md) | Question answering tasks based on PubMed research articles for biomedical understanding. | English |
121
+ | [qa4mre](qa4mre/README.md) | Question Answering for Machine Reading Evaluation, assessing comprehension and reasoning. | English |
122
+ | [qasper](qasper/README.md) | Question Answering dataset based on academic papers, testing in-depth scientific knowledge. | English |
123
+ | [race](race/README.md) | Reading comprehension assessment tasks based on English exams in China. | English |
124
+ | realtoxicityprompts | Tasks to evaluate language models for generating text with potential toxicity. | |
125
+ | [ruler](ruler/README.md) | RULER is a benchmark for testing how well language models handle long pieces of text. Requires custom arg (see readme) | English |
126
+ | [sciq](sciq/README.md) | Science Question Answering tasks to assess understanding of scientific concepts. | English |
127
+ | [score](score/README.md) | Systematic consistency and robustness evaluation for LLMs on 3 datasets(MMLU-Pro, Agi Eval and MATH) | English |
128
+ | [scrolls](scrolls/README.md) | Tasks that involve long-form reading comprehension across various domains. | English |
129
+ | [simple_cooccurrence_bias](simple_cooccurrence_bias/README.md) | A metric that evaluates language models for biases based on stereotypical word associations and co-occurrences in text. | English |
130
+ | [siqa](siqa/README.md) | Social Interaction Question Answering to evaluate common sense and social reasoning. | English |
131
+ | [spanish_bench](spanish_bench/README.md) | Collection of tasks in Spanish encompassing various evaluation areas. | Spanish |
132
+ | [squad_completion](squad_completion/README.md) | A variant of the SQuAD question answering task designed for zero-shot evaluation of small LMs. | English |
133
+ | [squadv2](squadv2/README.md) | Stanford Question Answering Dataset version 2, a reading comprehension benchmark. | English |
134
+ | [storycloze](storycloze/README.md) | Tasks to predict story endings, focusing on narrative logic and coherence. | English |
135
+ | [super_glue](super_glue/README.md) | A suite of challenging tasks designed to test a range of language understanding skills. | English |
136
+ | [swag](swag/README.md) | Situations With Adversarial Generations, predicting the next event in videos. | English |
137
+ | [swde](swde/README.md) | Information extraction tasks from semi-structured web pages. | English |
138
+ | [tinyBenchmarks](tinyBenchmarks/README.md) | Evaluation of large language models with fewer examples using tiny versions of popular benchmarks. | English |
139
+ | [tmmluplus](tmmluplus/README.md) | An extended set of tasks under the TMMLU framework for broader academic assessments. | Traditional Chinese |
140
+ | [toxigen](toxigen/README.md) | Tasks designed to evaluate language models on their propensity to generate toxic content. | English |
141
+ | [translation](translation/README.md) | Tasks focused on evaluating the language translation capabilities of models. | Arabic, English, Spanish, Basque, Hindi, Indonesian, Burmese, Russian, Swahili, Telugu, Chinese |
142
+ | [triviaqa](triviaqa/README.md) | A large-scale dataset for trivia question answering to test general knowledge. | English |
143
+ | [truthfulqa](truthfulqa/README.md) | A QA task aimed at evaluating the truthfulness and factual accuracy of model responses. | English |
144
+ | [turkishmmlu](turkishmmlu/README.md) | A multiple-choice QA test modeled after MMLU, written in Turkish based on Turkish high-school level exams. | Turkish |
145
+ | [unitxt](unitxt/README.md) | A number of tasks implemented using the unitxt library for flexible, shareable, and reusable data preparation and evaluation for generative AI. | English |
146
+ | [unscramble](unscramble/README.md) | Tasks involving the rearrangement of scrambled sentences to test syntactic understanding. | English |
147
+ | [webqs](webqs/README.md) | Web-based question answering tasks designed to evaluate internet search and retrieval. | English |
148
+ | [wikitext](wikitext/README.md) | Tasks based on text from Wikipedia articles to assess language modeling and generation. | English |
149
+ | [winogender](winogender/README.md) | A diagnostic dataset that tests for gender bias in coreference resolution by measuring how models associate pronouns with different occupations. | English |
150
+ | [winogrande](winogrande/README.md) | A large-scale dataset for coreference resolution, inspired by the Winograd Schema Challenge. | English |
151
+ | [wmdp](wmdp/README.md) | A benchmark with the objective of minimizing performance, based on potentially-sensitive multiple-choice knowledge questions. | English |
152
+ | [wmt2016](wmt2016/README.md) | Tasks from the WMT 2016 shared task, focusing on translation between multiple languages. | English, Czech, German, Finnish, Russian, Romanian, Turkish |
153
+ | [wsc273](wsc273/README.md) | The Winograd Schema Challenge, a test of commonsense reasoning and coreference resolution. | English |
154
+ | [xcopa](xcopa/README.md) | Cross-lingual Choice of Plausible Alternatives, testing reasoning in multiple languages. | Estonian, Haitian, Indonesian, Italian, Quechua, Swahili, Tamil, Thai, Turkish, Vietnamese, Chinese |
155
+ | [xnli](xnli/README.md) | Cross-Lingual Natural Language Inference to test understanding across different languages. | Arabic, Bulgarian, German, Greek, English, Spanish, French, Hindi, Russian, Swahili, Thai, Turkish, Urdu, Vietnamese, Chinese |
156
+ | [xnli_eu](xnli_eu/README.md) | Cross-lingual Natural Language Inference tasks in Basque. | Basque |
157
+ | [xquad](xquad/README.md) | Cross-lingual Question Answering Dataset in multiple languages. | Arabic, German, Greek, English, Spanish, Hindi, Romanian, Russian, Thai, Turkish, Vietnamese, Chinese |
158
+ | [xstorycloze](xstorycloze/README.md) | Cross-lingual narrative understanding tasks to predict story endings in multiple languages. | Russian, Simplified Chinese, Spanish, Arabic, Hindi, Indonesian, Telugu, Swahili, Basque, Burmese |
159
+ | [xwinograd](xwinograd/README.md) | Cross-lingual Winograd schema tasks for coreference resolution in multiple languages. | English, French, Japanese, Portuguese, Russian, Chinese |
160
+
161
+ ## Multilingual Tasks
162
+ | Task Family | Description | Modality |
163
+ |------------------------------|---------------------------------------------------------------------------------------------------------|-------------|
164
+ | [chartqa](chartqa/README.md) | A benchmark for question answering about charts that requires both visual and logical reasoning. | Image, Text |
165
+ | [mmmu](mmmu/README.md) | Evaluate multimodal models on massive multi-discipline tasks demanding college-level subject knowledge. | Image, Text |
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/__init__.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import inspect
3
+ import logging
4
+ import os
5
+ from functools import partial
6
+ from typing import Dict, List, Mapping, Optional, Union
7
+
8
+ from lm_eval import utils
9
+ from lm_eval.api.group import ConfigurableGroup, GroupConfig
10
+ from lm_eval.api.task import ConfigurableTask, Task
11
+ from lm_eval.evaluator_utils import get_subtask_list
12
+
13
+
14
+ GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
15
+
16
+ eval_logger = logging.getLogger(__name__)
17
+
18
+
19
+ class TaskManager:
20
+ """TaskManager indexes all tasks from the default `lm_eval/tasks/`
21
+ and an optional directory if provided.
22
+
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ verbosity: Optional[str] = None,
28
+ include_path: Optional[Union[str, List]] = None,
29
+ include_defaults: bool = True,
30
+ metadata: Optional[dict] = None,
31
+ ) -> None:
32
+ if verbosity is not None:
33
+ utils.setup_logging(verbosity)
34
+ self.include_path = include_path
35
+ self.metadata = metadata
36
+ self._task_index = self.initialize_tasks(
37
+ include_path=include_path, include_defaults=include_defaults
38
+ )
39
+ self._all_tasks = sorted(list(self._task_index.keys()))
40
+
41
+ self._all_groups = sorted(
42
+ [x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
43
+ )
44
+ self._all_subtasks = sorted(
45
+ [
46
+ x
47
+ for x in self._all_tasks
48
+ if self._task_index[x]["type"] in ["task", "python_task"]
49
+ ]
50
+ )
51
+ self._all_tags = sorted(
52
+ [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
53
+ )
54
+
55
+ self.task_group_map = collections.defaultdict(list)
56
+
57
+ def initialize_tasks(
58
+ self,
59
+ include_path: Optional[Union[str, List]] = None,
60
+ include_defaults: bool = True,
61
+ ) -> dict[str, dict]:
62
+ """Creates a dictionary of tasks indexes.
63
+
64
+ :param include_path: Union[str, List] = None
65
+ An additional path to be searched for tasks recursively.
66
+ Can provide more than one such path as a list.
67
+ :param include_defaults: bool = True
68
+ If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
69
+ return
70
+ Dictionary of task names as key and task metadata
71
+ """
72
+ if include_defaults:
73
+ all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
74
+ else:
75
+ all_paths = []
76
+ if include_path is not None:
77
+ if isinstance(include_path, str):
78
+ include_path = [include_path]
79
+ all_paths.extend(include_path)
80
+
81
+ task_index = {}
82
+ for task_dir in all_paths:
83
+ tasks = self._get_task_and_group(task_dir)
84
+ task_index = {**tasks, **task_index}
85
+
86
+ return task_index
87
+
88
+ @property
89
+ def all_tasks(self):
90
+ return self._all_tasks
91
+
92
+ @property
93
+ def all_groups(self):
94
+ return self._all_groups
95
+
96
+ @property
97
+ def all_subtasks(self):
98
+ return self._all_subtasks
99
+
100
+ @property
101
+ def all_tags(self):
102
+ return self._all_tags
103
+
104
+ @property
105
+ def task_index(self):
106
+ return self._task_index
107
+
108
+ def list_all_tasks(
109
+ self, list_groups=True, list_tags=True, list_subtasks=True
110
+ ) -> str:
111
+ from pytablewriter import MarkdownTableWriter
112
+
113
+ def sanitize_path(path):
114
+ # don't print full path if we are within the lm_eval/tasks dir !
115
+ # if we aren't though, provide the full path.
116
+ if "lm_eval/tasks/" in path:
117
+ return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
118
+ else:
119
+ return path
120
+
121
+ group_table = MarkdownTableWriter()
122
+ group_table.headers = ["Group", "Config Location"]
123
+ gt_values = []
124
+ for g in self.all_groups:
125
+ path = self.task_index[g]["yaml_path"]
126
+ if path == -1:
127
+ path = "---"
128
+ else:
129
+ path = sanitize_path(path)
130
+ gt_values.append([g, path])
131
+ group_table.value_matrix = gt_values
132
+
133
+ tag_table = MarkdownTableWriter()
134
+ tag_table.headers = ["Tag"]
135
+ tag_table.value_matrix = [[t] for t in self.all_tags]
136
+
137
+ subtask_table = MarkdownTableWriter()
138
+ subtask_table.headers = ["Task", "Config Location", "Output Type"]
139
+ st_values = []
140
+ for t in self.all_subtasks:
141
+ path = self.task_index[t]["yaml_path"]
142
+
143
+ output_type = ""
144
+
145
+ # read the yaml file to determine the output type
146
+ if path != -1:
147
+ config = utils.load_yaml_config(path, mode="simple")
148
+ if "output_type" in config:
149
+ output_type = config["output_type"]
150
+ elif (
151
+ "include" in config
152
+ ): # if no output type, check if there is an include with an output type
153
+ include_path = path.split("/")[:-1] + config["include"]
154
+ include_config = utils.load_yaml_config(include_path, mode="simple")
155
+ if "output_type" in include_config:
156
+ output_type = include_config["output_type"]
157
+
158
+ if path == -1:
159
+ path = "---"
160
+ else:
161
+ path = sanitize_path(path)
162
+ st_values.append([t, path, output_type])
163
+ subtask_table.value_matrix = st_values
164
+
165
+ result = "\n"
166
+ if list_groups:
167
+ result += group_table.dumps() + "\n\n"
168
+ if list_tags:
169
+ result += tag_table.dumps() + "\n\n"
170
+ if list_subtasks:
171
+ result += subtask_table.dumps() + "\n\n"
172
+ return result
173
+
174
+ def match_tasks(self, task_list: list[str]) -> list[str]:
175
+ return utils.pattern_match(task_list, self.all_tasks)
176
+
177
+ def _name_is_registered(self, name: str) -> bool:
178
+ if name in self.all_tasks:
179
+ return True
180
+ return False
181
+
182
+ def _name_is_task(self, name: str) -> bool:
183
+ if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
184
+ return True
185
+ return False
186
+
187
+ def _name_is_tag(self, name: str) -> bool:
188
+ if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
189
+ return True
190
+ return False
191
+
192
+ def _name_is_group(self, name: str) -> bool:
193
+ if self._name_is_registered(name) and (
194
+ self.task_index[name]["type"] == "group"
195
+ ):
196
+ return True
197
+ return False
198
+
199
+ def _name_is_python_task(self, name: str) -> bool:
200
+ if self._name_is_registered(name) and (
201
+ self.task_index[name]["type"] == "python_task"
202
+ ):
203
+ return True
204
+ return False
205
+
206
+ def _config_is_task(self, config: dict) -> bool:
207
+ if ("task" in config) and isinstance(config["task"], str):
208
+ return True
209
+ return False
210
+
211
+ def _config_is_group(self, config: dict) -> bool:
212
+ if ("task" in config) and isinstance(config["task"], list):
213
+ return True
214
+ return False
215
+
216
+ def _config_is_python_task(self, config: dict) -> bool:
217
+ if "class" in config:
218
+ return True
219
+ return False
220
+
221
+ def _get_yaml_path(self, name: str):
222
+ if name not in self.task_index:
223
+ raise ValueError
224
+ return self.task_index[name]["yaml_path"]
225
+
226
+ def _get_config(self, name):
227
+ if name not in self.task_index:
228
+ raise ValueError
229
+ yaml_path = self._get_yaml_path(name)
230
+ if yaml_path == -1:
231
+ return {}
232
+ else:
233
+ return utils.load_yaml_config(yaml_path, mode="full")
234
+
235
+ def _get_tasklist(self, name):
236
+ if self._name_is_task(name):
237
+ raise ValueError
238
+ return self.task_index[name]["task"]
239
+
240
+ def _process_alias(self, config, group=None):
241
+ # If the group is not the same as the original
242
+ # group which the group alias was intended for,
243
+ # Set the group_alias to None instead.
244
+ if ("group_alias" in config) and ("group" in config) and group is not None:
245
+ if config["group"] != group:
246
+ config["group_alias"] = None
247
+ return config
248
+
249
+ def _class_has_config_in_constructor(self, cls):
250
+ constructor = getattr(cls, "__init__", None)
251
+ return (
252
+ "config" in inspect.signature(constructor).parameters
253
+ if constructor
254
+ else False
255
+ )
256
+
257
+ def _load_individual_task_or_group(
258
+ self,
259
+ name_or_config: Optional[Union[str, dict]] = None,
260
+ parent_name: Optional[str] = None,
261
+ update_config: Optional[dict] = None,
262
+ ) -> Mapping:
263
+ def _load_task(config, task):
264
+ if "include" in config:
265
+ config = {
266
+ **utils.load_yaml_config(
267
+ yaml_path=None,
268
+ yaml_config={"include": config.pop("include")},
269
+ mode="full",
270
+ ),
271
+ **config,
272
+ }
273
+ if self._config_is_python_task(config):
274
+ if self._class_has_config_in_constructor(config["class"]):
275
+ task_object = config["class"](config=config)
276
+ else:
277
+ task_object = config["class"]()
278
+ if isinstance(task_object, ConfigurableTask):
279
+ # very scuffed: set task name here. TODO: fixme?
280
+ task_object.config.task = task
281
+ else:
282
+ if self.metadata is not None:
283
+ config["metadata"] = config.get("metadata", {}) | self.metadata
284
+ else:
285
+ config["metadata"] = config.get("metadata", {})
286
+ task_object = ConfigurableTask(config=config)
287
+
288
+ return {task: task_object}
289
+
290
+ def _get_group_and_subtask_from_config(
291
+ config: dict,
292
+ ) -> tuple[ConfigurableGroup, list[str]]:
293
+ if self.metadata is not None:
294
+ config["metadata"] = config.get("metadata", {}) | self.metadata
295
+ group_name = ConfigurableGroup(config=config)
296
+ subtask_list = []
297
+ for task in group_name.config["task"]:
298
+ if isinstance(task, str) and self._name_is_tag(task):
299
+ subtask_list.extend(self._get_tasklist(task))
300
+ else:
301
+ subtask_list.append(task)
302
+ return group_name, subtask_list
303
+
304
+ def _process_group_config(
305
+ config: dict, update_config: dict = None
306
+ ) -> tuple[dict, dict]:
307
+ if update_config is not None:
308
+ config = {**config, **update_config}
309
+ _update_config = {
310
+ k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
311
+ }
312
+ if not bool(_update_config):
313
+ _update_config = None
314
+
315
+ group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
316
+ return group_config, _update_config
317
+
318
+ if isinstance(name_or_config, str):
319
+ if update_config is not None:
320
+ # Process name_or_config as a dict instead
321
+ name_or_config = {"task": name_or_config, **update_config}
322
+ elif self._name_is_task(name_or_config) or self._name_is_python_task(
323
+ name_or_config
324
+ ):
325
+ task_config = self._get_config(name_or_config)
326
+ return _load_task(task_config, task=name_or_config)
327
+ else:
328
+ subtask_list = self._get_tasklist(name_or_config)
329
+ if subtask_list == -1:
330
+ group_config = self._get_config(name_or_config)
331
+ group_config, update_config = _process_group_config(group_config)
332
+ group_name, subtask_list = _get_group_and_subtask_from_config(
333
+ group_config
334
+ )
335
+ else:
336
+ if self._name_is_tag(name_or_config):
337
+ fn = partial(
338
+ self._load_individual_task_or_group,
339
+ update_config=name_or_config
340
+ if isinstance(name_or_config, dict)
341
+ else None,
342
+ )
343
+ return dict(
344
+ collections.ChainMap(*map(fn, reversed(subtask_list)))
345
+ )
346
+ else:
347
+ group_name = ConfigurableGroup(
348
+ config={"group": name_or_config, "task": subtask_list}
349
+ )
350
+
351
+ if isinstance(name_or_config, dict):
352
+ if self._config_is_task(name_or_config):
353
+ name = name_or_config.pop("task")
354
+ if update_config is not None:
355
+ name_or_config = {**name_or_config, **update_config}
356
+ # If the name is registered as a group
357
+ if self._name_is_group(name):
358
+ group_config = self._get_config(name)
359
+
360
+ group_config, update_config = _process_group_config(
361
+ group_config, name_or_config
362
+ )
363
+ group_name, subtask_list = _get_group_and_subtask_from_config(
364
+ group_config
365
+ )
366
+ elif self._name_is_tag(name):
367
+ subtask_list = self._get_tasklist(name)
368
+ fn = partial(
369
+ self._load_individual_task_or_group,
370
+ update_config=name_or_config,
371
+ )
372
+ return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
373
+ else:
374
+ if self._name_is_registered(name):
375
+ base_task_config = self._get_config(name)
376
+
377
+ # Check if this is a duplicate.
378
+ if parent_name is not None:
379
+ num_duplicate = len(
380
+ list(
381
+ filter(
382
+ lambda x: x.startswith(name),
383
+ self.task_group_map[parent_name],
384
+ )
385
+ )
386
+ )
387
+ if num_duplicate > 0:
388
+ name = f"{name}-{num_duplicate}"
389
+ self.task_group_map[parent_name].append(name)
390
+
391
+ task_config = {
392
+ **base_task_config,
393
+ **name_or_config,
394
+ }
395
+ else:
396
+ task_config = name_or_config
397
+ return _load_task(task_config, task=name)
398
+ else:
399
+ group_config, update_config = _process_group_config(name_or_config)
400
+ group_name, subtask_list = _get_group_and_subtask_from_config(
401
+ group_config
402
+ )
403
+
404
+ fn = partial(
405
+ self._load_individual_task_or_group,
406
+ parent_name=group_name,
407
+ update_config=update_config,
408
+ )
409
+ return {
410
+ group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
411
+ }
412
+
413
+ def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
414
+ """Loads a dictionary of task objects from a list
415
+
416
+ :param task_list: Union[str, list] = None
417
+ Single string or list of string of task names to be loaded
418
+
419
+ :return
420
+ Dictionary of task objects
421
+ """
422
+ if isinstance(task_list, str):
423
+ task_list = [task_list]
424
+
425
+ all_loaded_tasks = dict(
426
+ collections.ChainMap(
427
+ *map(
428
+ lambda task: self._load_individual_task_or_group(task),
429
+ task_list,
430
+ )
431
+ )
432
+ )
433
+ return all_loaded_tasks
434
+
435
+ def load_config(self, config: Dict):
436
+ return self._load_individual_task_or_group(config)
437
+
438
+ def _get_task_and_group(self, task_dir: str):
439
+ """Creates a dictionary of tasks index with the following metadata,
440
+ - `type`, that can be either `task`, `python_task`, `group` or `tags`.
441
+ `task` refer to regular task configs, `python_task` are special
442
+ yaml files that only consists of `task` and `class` parameters.
443
+ `group` are group configs. `tags` are labels that can be assigned
444
+ to tasks to assist in sorting and calling tasks of certain themes.
445
+ - `yaml_path`, path to the yaml file. If the entry is a `group` that
446
+ was configured through a task config, the yaml_path will be -1
447
+ and all subtasks will be listed in `task` (see below)
448
+ - `task`, reserved for entries with `type` as `group`. This will list
449
+ all subtasks. When a group config is created (as opposed to task
450
+ config having `group` parameter set), this will be set to -1 to
451
+ avoid recursive indexing. The whole list of subtasks will be loaded
452
+ at evaluation.
453
+
454
+ :param task_dir: str
455
+ A directory to check for tasks
456
+
457
+ :return
458
+ Dictionary of task names as key and task metadata
459
+ """
460
+
461
+ def _populate_tags_and_groups(config, task, tasks_and_groups, print_info):
462
+ # TODO: remove group in next release
463
+ if "tag" in config:
464
+ attr_list = config["tag"]
465
+ if isinstance(attr_list, str):
466
+ attr_list = [attr_list]
467
+
468
+ for tag in attr_list:
469
+ if tag not in tasks_and_groups:
470
+ tasks_and_groups[tag] = {
471
+ "type": "tag",
472
+ "task": [task],
473
+ "yaml_path": -1,
474
+ }
475
+ elif tasks_and_groups[tag]["type"] != "tag":
476
+ eval_logger.info(
477
+ f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
478
+ "This may affect tasks you want to call."
479
+ )
480
+ break
481
+ else:
482
+ tasks_and_groups[tag]["task"].append(task)
483
+
484
+ # TODO: remove group in next release
485
+ print_info = True
486
+ ignore_dirs = [
487
+ "__pycache__",
488
+ ".ipynb_checkpoints",
489
+ ]
490
+ tasks_and_groups = collections.defaultdict()
491
+ for root, dirs, file_list in os.walk(task_dir):
492
+ dirs[:] = [d for d in dirs if d not in ignore_dirs]
493
+ for f in file_list:
494
+ if f.endswith(".yaml"):
495
+ yaml_path = os.path.join(root, f)
496
+ config = utils.load_yaml_config(yaml_path, mode="simple")
497
+ if self._config_is_python_task(config):
498
+ # This is a python class config
499
+ task = config["task"]
500
+ tasks_and_groups[task] = {
501
+ "type": "python_task",
502
+ "yaml_path": yaml_path,
503
+ }
504
+ _populate_tags_and_groups(
505
+ config, task, tasks_and_groups, print_info
506
+ )
507
+ elif self._config_is_group(config):
508
+ # This is a group config
509
+ tasks_and_groups[config["group"]] = {
510
+ "type": "group",
511
+ "task": -1, # This signals that
512
+ # we don't need to know
513
+ # the task list for indexing
514
+ # as it can be loaded
515
+ # when called.
516
+ "yaml_path": yaml_path,
517
+ }
518
+
519
+ # # Registered the level 1 tasks from a group config
520
+ # for config in config["task"]:
521
+ # if isinstance(config, dict) and self._config_is_task(config):
522
+ # task = config["task"]
523
+ # tasks_and_groups[task] = {
524
+ # "type": "task",
525
+ # "yaml_path": yaml_path,
526
+ # }
527
+
528
+ elif self._config_is_task(config):
529
+ # This is a task config
530
+ task = config["task"]
531
+ tasks_and_groups[task] = {
532
+ "type": "task",
533
+ "yaml_path": yaml_path,
534
+ }
535
+ _populate_tags_and_groups(
536
+ config, task, tasks_and_groups, print_info
537
+ )
538
+ else:
539
+ eval_logger.debug(f"File {f} in {root} could not be loaded")
540
+
541
+ return tasks_and_groups
542
+
543
+
544
+ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
545
+ if "task" in task_config:
546
+ return task_config["task"]
547
+ if "dataset_name" in task_config:
548
+ return "{dataset_path}_{dataset_name}".format(**task_config)
549
+ else:
550
+ return "{dataset_path}".format(**task_config)
551
+
552
+
553
+ def get_task_name_from_object(task_object):
554
+ if hasattr(task_object, "config"):
555
+ return task_object._config["task"]
556
+
557
+ # TODO: scrap this
558
+ # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
559
+ return (
560
+ task_object.EVAL_HARNESS_NAME
561
+ if hasattr(task_object, "EVAL_HARNESS_NAME")
562
+ else type(task_object).__name__
563
+ )
564
+
565
+
566
+ def _check_duplicates(task_dict: dict) -> None:
567
+ """helper function solely used in validating get_task_dict output.
568
+ Takes the output of lm_eval.evaluator_utils.get_subtask_list and
569
+ returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
570
+ "oversubscribed" to several disjoint groups.
571
+ """
572
+ subtask_names = []
573
+ for key, value in task_dict.items():
574
+ subtask_names.extend(value)
575
+
576
+ duplicate_tasks = {
577
+ task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
578
+ }
579
+
580
+ # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
581
+ competing_groups = [
582
+ group
583
+ for group in task_dict.keys()
584
+ if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
585
+ ]
586
+
587
+ if len(duplicate_tasks) > 0:
588
+ raise ValueError(
589
+ f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
590
+ )
591
+
592
+
593
+ def get_task_dict(
594
+ task_name_list: Union[str, List[Union[str, Dict, Task]]],
595
+ task_manager: Optional[TaskManager] = None,
596
+ ):
597
+ """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
598
+
599
+ :param task_name_list: List[Union[str, Dict, Task]]
600
+ Name of model or LM object, see lm_eval.models.get_model
601
+ :param task_manager: TaskManager = None
602
+ A TaskManager object that stores indexed tasks. If not set,
603
+ task_manager will load one. This should be set by the user
604
+ if there are additional paths that want to be included
605
+ via `include_path`
606
+
607
+ :return
608
+ Dictionary of task objects
609
+ """
610
+
611
+ task_name_from_string_dict = {}
612
+ task_name_from_config_dict = {}
613
+ task_name_from_object_dict = {}
614
+
615
+ if isinstance(task_name_list, str):
616
+ task_name_list = [task_name_list]
617
+ elif isinstance(task_name_list, list):
618
+ if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
619
+ raise TypeError(
620
+ "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
621
+ )
622
+ else:
623
+ raise TypeError(
624
+ f"Expected a 'str' or 'list' but received {type(task_name_list)}."
625
+ )
626
+
627
+ string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
628
+ others_task_name_list = [
629
+ task for task in task_name_list if not isinstance(task, str)
630
+ ]
631
+ if len(string_task_name_list) > 0:
632
+ if task_manager is None:
633
+ task_manager = TaskManager()
634
+
635
+ task_name_from_string_dict = task_manager.load_task_or_group(
636
+ string_task_name_list
637
+ )
638
+
639
+ for task_element in others_task_name_list:
640
+ if isinstance(task_element, dict):
641
+ task_name_from_config_dict = {
642
+ **task_name_from_config_dict,
643
+ **task_manager.load_config(config=task_element),
644
+ }
645
+
646
+ elif isinstance(task_element, Task):
647
+ task_name_from_object_dict = {
648
+ **task_name_from_object_dict,
649
+ get_task_name_from_object(task_element): task_element,
650
+ }
651
+
652
+ if not set(task_name_from_string_dict.keys()).isdisjoint(
653
+ set(task_name_from_object_dict.keys())
654
+ ):
655
+ raise ValueError
656
+
657
+ final_task_dict = {
658
+ **task_name_from_string_dict,
659
+ **task_name_from_config_dict,
660
+ **task_name_from_object_dict,
661
+ }
662
+
663
+ # behavior can get odd if one tries to invoke several groups that "compete" for the same task.
664
+ # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
665
+ # and we'd be unsure which to use and report.)
666
+ # we explicitly check and error in this case.
667
+ _check_duplicates(get_subtask_list(final_task_dict))
668
+
669
+ return final_task_dict
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include: humaneval.yaml
2
+ task: humaneval_5
3
+ repeats: 5
4
+ metric_list:
5
+ - metric: !function utils.pass_at_k
6
+ aggregation: mean
7
+ higher_is_better: true
8
+ k: [1,2,3,4,5]
9
+ generation_kwargs:
10
+ until:
11
+ - "\nclass"
12
+ - "\ndef"
13
+ - "\n#"
14
+ - "\nif"
15
+ - "\nprint"
16
+ max_gen_toks: 1024
17
+ do_sample: true
18
+ temperature: 0.2
19
+ top_p: 0.95
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5_instruct_noprefix.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include: humaneval_5.yaml
2
+ task: humaneval_5_instruct_noprefix
3
+ doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n```{{prompt}}"
4
+ gen_prefix: "```python\n"
5
+ generation_kwargs:
6
+ until:
7
+ - "\nassert"
8
+ - "\n# Test"
9
+ filter_list:
10
+ - name: "create_test"
11
+ filter:
12
+ - function: "custom"
13
+ filter_fn: !function utils.build_predictions_instruct
14
+ metadata:
15
+ version: 2.0
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_64.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include: humaneval.yaml
2
+ task: humaneval_64
3
+ repeats: 64
4
+ metric_list:
5
+ - metric: !function utils.pass_at_k
6
+ aggregation: mean
7
+ higher_is_better: true
8
+ k: [2,8,16,32,64]
9
+ generation_kwargs:
10
+ until:
11
+ - "\nclass"
12
+ - "\ndef"
13
+ - "\n#"
14
+ - "\nif"
15
+ - "\nprint"
16
+ max_gen_toks: 1024
17
+ do_sample: true
18
+ temperature: 0.2
19
+ top_p: 0.95
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_instruct.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include: humaneval.yaml
2
+ task: humaneval_instruct
3
+ doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n```{{prompt}}"
4
+ gen_prefix: "Here is the completed function:\n```python\n{{prompt}}\n"
5
+ filter_list:
6
+ - name: "create_test"
7
+ filter:
8
+ - function: "custom"
9
+ filter_fn: !function utils.build_predictions_instruct
10
+ metadata:
11
+ version: 2.0
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_instruct_noprefix.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include: humaneval.yaml
2
+ task: humaneval_instruct_noprefix
3
+ doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n```{{prompt}}```"
4
+ gen_prefix: "```python\n"
5
+ generation_kwargs:
6
+ until:
7
+ - "\nassert"
8
+ - "\n# Test"
9
+ filter_list:
10
+ - name: "create_test"
11
+ filter:
12
+ - function: "custom"
13
+ filter_fn: !function utils.build_predictions_instruct
14
+ metadata:
15
+ version: 2.0
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/sanitize_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import traceback
3
+
4
+ from typing import Dict, List, Optional, Set, Tuple
5
+
6
+ def refine_text(text: str) -> str:
7
+ text = text.replace("\t", " ")
8
+ text = text.replace("\r\n", "\n").replace("\r", "\n")
9
+ return text.strip() + "\n"
10
+
11
+ def syntax_check(code, verbose = False):
12
+ try:
13
+ ast.parse(code)
14
+ return True
15
+ except (SyntaxError, MemoryError):
16
+ if verbose:
17
+ traceback.print_exc()
18
+ return False
19
+
20
+ def extract_longest_valid_code(text: str) -> str:
21
+ lines = text.splitlines()
22
+
23
+ if len(lines) > 100:
24
+ lines = lines[:100]
25
+ max_valid_lines = 0
26
+ max_valid_snippet = ""
27
+
28
+ for i in range(len(lines)):
29
+ for j in range(i, len(lines)):
30
+ current_snippet = "\n".join(lines[i:j+1])
31
+ if syntax_check(current_snippet):
32
+ valid_line_count = sum(1 for line in lines[i:j+1] if line.strip())
33
+ if valid_line_count > max_valid_lines:
34
+ max_valid_lines = valid_line_count
35
+ max_valid_snippet = current_snippet
36
+
37
+ return max_valid_snippet
38
+
39
+ def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
40
+ name2deps = {}
41
+ for name, node in nodes:
42
+ deps = set()
43
+ stack = [node]
44
+ while stack:
45
+ current = stack.pop()
46
+ for child in ast.iter_child_nodes(current):
47
+ if isinstance(child, ast.Name):
48
+ deps.add(child.id)
49
+ elif isinstance(child, ast.Attribute):
50
+ deps.add(child.attr)
51
+ else:
52
+ stack.append(child)
53
+ name2deps[name] = deps
54
+ return name2deps
55
+
56
+ def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
57
+ visited = set()
58
+ to_visit = [entrypoint]
59
+
60
+ while to_visit:
61
+ current = to_visit.pop(0)
62
+ if current not in visited:
63
+ visited.add(current)
64
+ to_visit.extend(call_graph.get(current, set()) - visited)
65
+
66
+ return visited
67
+
68
+ def get_definition_name(node: ast.AST) -> Optional[str]:
69
+ if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
70
+ return node.name
71
+ elif isinstance(node, ast.Assign):
72
+ targets = node.targets
73
+ if targets and isinstance(targets[0], ast.Name):
74
+ return targets[0].id
75
+ return None
76
+
77
+ def has_return_statement(node: ast.AST) -> bool:
78
+ return any(isinstance(n, ast.Return) for n in ast.walk(node))
79
+
80
+ def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
81
+
82
+ text = refine_text(text)
83
+
84
+ # text = python_extract(text)
85
+
86
+ code = extract_longest_valid_code(text)
87
+ tree = ast.parse(code)
88
+
89
+ definitions = {}
90
+
91
+ imports = []
92
+
93
+ for node in tree.body:
94
+ if isinstance(node, (ast.Import, ast.ImportFrom)):
95
+ imports.append(node)
96
+ elif isinstance(node, ast.ClassDef):
97
+ name = node.name
98
+ definitions[name] = ('class', node)
99
+ elif isinstance(node, ast.FunctionDef):
100
+ name = node.name
101
+ if has_return_statement(node):
102
+ definitions[name] = ('function', node)
103
+ elif isinstance(node, ast.Assign):
104
+ name = get_definition_name(node)
105
+ if name:
106
+ definitions[name] = ('variable', node)
107
+
108
+ if entrypoint:
109
+ name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
110
+ reachable = get_function_dependency(entrypoint, name2deps)
111
+
112
+ sanitized_output = []
113
+
114
+ for node in imports:
115
+ sanitized_output.append(ast.unparse(node))
116
+
117
+ for name, (_, node) in definitions.items():
118
+ if not entrypoint or name in reachable:
119
+ sanitized_output.append(ast.unparse(node))
120
+
121
+ return "\n".join(sanitized_output)
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/utils.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate as hf_evaluate
2
+
3
+ from lm_eval.tasks.humaneval.sanitize_utils import sanitize
4
+
5
+
6
+ try:
7
+ compute_ = hf_evaluate.load("code_eval")
8
+ test_cases = ["assert add(2, 3)==5"]
9
+ candidates = [["def add(a,b): return a*b"]]
10
+ results = compute_.compute(references=test_cases, predictions=candidates, k=[1])
11
+ except Exception as e:
12
+ raise e
13
+
14
+
15
+ def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
16
+ global compute_
17
+ assert k is not None
18
+ if isinstance(k, int):
19
+ k = [k]
20
+
21
+ processed_predictions = []
22
+ for preds in predictions:
23
+ processed_preds = []
24
+ for p in preds:
25
+ processed_preds.append(p.strip("```")[0] if "```" in p else p)
26
+ processed_predictions.append(processed_preds)
27
+
28
+ res = compute_.compute(
29
+ references=references,
30
+ predictions=predictions,
31
+ k=k,
32
+ )
33
+ return res[0]
34
+
35
+
36
+ def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
37
+ return [[doc["prompt"] + r for r in resp] for resp, doc in zip(resps, docs)]
38
+
39
+
40
+ def build_predictions_instruct(
41
+ resps: list[list[str]], docs: list[dict]
42
+ ) -> list[list[str]]:
43
+ return [
44
+ [
45
+ sanitize(
46
+ doc["prompt"] + "\n" + r.split('```python\n', 1)[-1].split('```')[0],
47
+ doc["entry_point"]
48
+ )
49
+ for r in resp
50
+ ]
51
+ for resp, doc in zip(resps, docs)
52
+ ]
Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/utils.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import fnmatch
3
+ import functools
4
+ import hashlib
5
+ import importlib.util
6
+ import inspect
7
+ import json
8
+ import logging
9
+ import os
10
+ import re
11
+ from dataclasses import asdict, is_dataclass
12
+ from itertools import islice
13
+ from pathlib import Path
14
+ from typing import Any, Callable, Generator, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import yaml
18
+ from jinja2 import BaseLoader, Environment, StrictUndefined
19
+
20
+
21
+ SPACING = " " * 47
22
+
23
+ HIGHER_IS_BETTER_SYMBOLS = {
24
+ True: "↑",
25
+ False: "↓",
26
+ }
27
+
28
+
29
+ def setup_logging(verbosity=logging.INFO):
30
+ # Configure the root logger
31
+ class CustomFormatter(logging.Formatter):
32
+ def format(self, record):
33
+ if record.name.startswith("lm_eval."):
34
+ record.name = record.name[len("lm_eval.") :]
35
+ return super().format(record)
36
+
37
+ formatter = CustomFormatter(
38
+ "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
39
+ datefmt="%Y-%m-%d:%H:%M:%S",
40
+ )
41
+
42
+ log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity
43
+
44
+ level_map = {
45
+ "DEBUG": logging.DEBUG,
46
+ "INFO": logging.INFO,
47
+ "WARNING": logging.WARNING,
48
+ "ERROR": logging.ERROR,
49
+ "CRITICAL": logging.CRITICAL,
50
+ }
51
+
52
+ log_level = level_map.get(str(log_level).upper(), logging.INFO)
53
+
54
+ if not logging.root.handlers:
55
+ handler = logging.StreamHandler()
56
+ handler.setFormatter(formatter)
57
+
58
+ root_logger = logging.getLogger()
59
+ root_logger.addHandler(handler)
60
+ root_logger.setLevel(log_level)
61
+
62
+ if log_level == logging.DEBUG:
63
+ third_party_loggers = ["urllib3", "filelock", "fsspec"]
64
+ for logger_name in third_party_loggers:
65
+ logging.getLogger(logger_name).setLevel(logging.INFO)
66
+ else:
67
+ logging.getLogger().setLevel(log_level)
68
+
69
+
70
+ def hash_string(string: str) -> str:
71
+ return hashlib.sha256(string.encode("utf-8")).hexdigest()
72
+
73
+
74
+ def escaped_split(text, sep_char, maxsplit=-1):
75
+ """Split text into a list on occurrences of the given separation
76
+ character `sep_char`. The separation character may be escaped by a
77
+ backslash to avoid splitting at that location.
78
+
79
+ The separation character must be a string of size 1.
80
+
81
+ If `maxsplit` is given, at most `maxsplit` splits are done (thus,
82
+ the list will have at most `maxsplit + 1` elements). If `maxsplit`
83
+ is not specified or less than 0, then there is no limit on the
84
+ number of splits (all possible splits are made).
85
+ """
86
+ assert len(sep_char) == 1, (
87
+ "separation string must be a single character for escaped splitting"
88
+ )
89
+
90
+ if maxsplit == 0:
91
+ return text
92
+ maxsplit = max(0, maxsplit)
93
+
94
+ return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
95
+
96
+
97
+ def handle_arg_string(arg):
98
+ if arg.lower() == "true":
99
+ return True
100
+ elif arg.lower() == "false":
101
+ return False
102
+ elif arg.isnumeric():
103
+ return int(arg)
104
+ try:
105
+ return float(arg)
106
+ except ValueError:
107
+ return arg
108
+
109
+
110
+ def handle_non_serializable(o):
111
+ if isinstance(o, np.int64) or isinstance(o, np.int32):
112
+ return int(o)
113
+ elif isinstance(o, set):
114
+ return list(o)
115
+ else:
116
+ return str(o)
117
+
118
+
119
+ def sanitize_list(sub):
120
+ """
121
+ Takes possible nested list and recursively converts all inner component to strings
122
+ """
123
+ if isinstance(sub, list):
124
+ return [sanitize_list(item) for item in sub]
125
+ if isinstance(sub, tuple):
126
+ return tuple(sanitize_list(item) for item in sub)
127
+ else:
128
+ return str(sub)
129
+
130
+
131
+ def simple_parse_args_string(args_string: Optional[str]) -> dict:
132
+ """
133
+ Parses something like
134
+ args1=val1,arg2=val2
135
+ Into a dictionary
136
+ """
137
+ if args_string is None:
138
+ return {}
139
+ args_string = args_string.strip()
140
+ if not args_string:
141
+ return {}
142
+ arg_list = [arg for arg in args_string.split(",") if arg]
143
+ args_dict = {
144
+ kv[0]: handle_arg_string("=".join(kv[1:]))
145
+ for kv in [arg.split("=") for arg in arg_list]
146
+ }
147
+ return args_dict
148
+
149
+
150
+ def join_iters(iters):
151
+ for iter in iters:
152
+ yield from iter
153
+
154
+
155
+ def group(arr, fn):
156
+ res = collections.defaultdict(list)
157
+
158
+ for ob in arr:
159
+ res[fn(ob)].append(ob)
160
+
161
+ return list(res.values())
162
+
163
+
164
+ # Returns a list containing all values of the source_list that
165
+ # match at least one of the patterns
166
+ def pattern_match(patterns, source_list):
167
+ if isinstance(patterns, str):
168
+ patterns = [patterns]
169
+
170
+ task_names = set()
171
+ for pattern in patterns:
172
+ for matching in fnmatch.filter(source_list, pattern):
173
+ task_names.add(matching)
174
+ return sorted(list(task_names))
175
+
176
+
177
+ def softmax(x) -> np.ndarray:
178
+ """Compute softmax values for each sets of scores in x."""
179
+ e_x = np.exp(x - np.max(x))
180
+ return e_x / e_x.sum()
181
+
182
+
183
+ def general_detokenize(string) -> str:
184
+ string = string.replace(" n't", "n't")
185
+ string = string.replace(" )", ")")
186
+ string = string.replace("( ", "(")
187
+ string = string.replace('" ', '"')
188
+ string = string.replace(' "', '"')
189
+ string = re.sub(r" (['.,])", r"\1", string)
190
+ return string
191
+
192
+
193
+ def get_file_task_name(filename: str) -> str:
194
+ """
195
+ Given the sample results filenames, extracts and returns the task name.
196
+ """
197
+ return filename[filename.find("_") + 1 : filename.rfind("_")]
198
+
199
+
200
+ def get_file_datetime(filename: str) -> str:
201
+ """
202
+ Given the results and sample results filenames, extracts and returns the datetime.
203
+ """
204
+ return filename[filename.rfind("_") + 1 :].replace(".jsonl", "")
205
+
206
+
207
+ def sanitize_model_name(model_name: str) -> str:
208
+ """
209
+ Given the model name, returns a sanitized version of it.
210
+ """
211
+ return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)
212
+
213
+
214
+ def sanitize_task_name(task_name: str) -> str:
215
+ """
216
+ Given the task name, returns a sanitized version of it.
217
+ """
218
+ return re.sub(r"\W", "_", task_name)
219
+
220
+
221
+ def get_latest_filename(filenames: List[str]) -> str:
222
+ """
223
+ Given a list of filenames, returns the filename with the latest datetime.
224
+ """
225
+ return max(filenames, key=lambda f: get_file_datetime(f))
226
+
227
+
228
+ def get_results_filenames(filenames: List[str]) -> List[str]:
229
+ """
230
+ Extracts filenames that correspond to aggregated results.
231
+ """
232
+ return [f for f in filenames if "/results_" in f and ".json" in f]
233
+
234
+
235
+ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
236
+ """
237
+ Extracts filenames that correspond to sample results.
238
+ """
239
+ return [f for f in filenames if "/samples_" in f and ".json" in f]
240
+
241
+
242
+ def get_rolling_token_windows(
243
+ token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
244
+ ) -> Generator[Tuple[List[int], List[int]], None, None]:
245
+ """
246
+ - context_len allows for a rolling window context, allowing each prediction window to potentially
247
+ condition on some context
248
+
249
+ :param token_list: list
250
+ List of tokens to be PREDICTED
251
+ :param max_seq_len: int
252
+ max_seq_len of model (or max_seq_len we want to use)
253
+ :param context_len: int
254
+ Amount of desired token context for prediction. Needs to be at least 1.
255
+ :param prefix_token: token
256
+ Dummy token like <eos> so the first token has something to condition on
257
+ :return: generator
258
+ Generator of tuples
259
+ (input_tokens, pred_tokens)
260
+ Note: Score only the last len(pred_tokens) logits of the LM
261
+ """
262
+ assert 1 <= context_len <= max_seq_len
263
+ if not token_list:
264
+ return
265
+ # +1 offset, going from input->preds
266
+ pred_len = max_seq_len - context_len + 1
267
+ predicted = 0
268
+
269
+ # Special handling for first window: predict all tokens
270
+ first_seq_len = min(max_seq_len, len(token_list))
271
+ yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
272
+ predicted += first_seq_len
273
+
274
+ while predicted < len(token_list):
275
+ window_pred_len = min(len(token_list) - predicted, pred_len)
276
+ window_end = predicted + window_pred_len
277
+
278
+ yield (
279
+ token_list[window_end - max_seq_len - 1 : window_end - 1],
280
+ token_list[window_end - window_pred_len : window_end],
281
+ )
282
+ predicted += window_pred_len
283
+
284
+
285
+ def make_disjoint_window(
286
+ pair: Tuple[List[int], List[int]],
287
+ ) -> Tuple[List[int], List[int]]:
288
+ """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
289
+ a, b = pair
290
+ return a[: len(a) - (len(b) - 1)], b
291
+
292
+
293
+ class EnhancedJSONEncoder(json.JSONEncoder):
294
+ """
295
+ Provides a proper json encoding for the loggers and trackers json dumps.
296
+ Notably manages the json encoding of dataclasses.
297
+ """
298
+
299
+ def default(self, o):
300
+ if is_dataclass(o):
301
+ return asdict(o)
302
+ return super().default(o)
303
+
304
+
305
+ class Reorderer:
306
+ def __init__(self, arr: List[Any], fn: Callable) -> None:
307
+ """Reorder an array according to some function
308
+
309
+ Args:
310
+ arr (List[Any]): The initial array
311
+ fn (Callable[[Any], Any]): A function to determine the priority of elements
312
+ """
313
+ self.size = len(arr)
314
+ arr = list(enumerate(arr))
315
+ arr = group(arr, lambda x: fn(x[1]))
316
+ # arr = [([y[0] for y in x], x[0][1]) for x in arr]
317
+ # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
318
+ arr = [([y[0]], x[0][1]) for x in arr for y in x]
319
+ arr.sort(key=lambda x: fn(x[1]))
320
+
321
+ self.arr = arr
322
+
323
+ def get_reordered(self):
324
+ """Gets the reordered array
325
+
326
+ Returns:
327
+ List[Any]: The reordered array
328
+ """
329
+ return [x[1] for x in self.arr]
330
+
331
+ def get_original(self, newarr):
332
+ """Restores the original order of a new array based on the old array's order
333
+
334
+ Args:
335
+ newarr (List[Any]): The array to be restored
336
+
337
+ Returns:
338
+ List[Any]: The array restored to the original order
339
+ """
340
+ res = [None] * self.size
341
+ cov = [False] * self.size
342
+
343
+ for (inds, _), v in zip(self.arr, newarr):
344
+ for ind in inds:
345
+ res[ind] = v
346
+ cov[ind] = True
347
+
348
+ assert all(cov)
349
+
350
+ return res
351
+
352
+
353
+ def make_table(result_dict, column: str = "results", sort_results: bool = False):
354
+ """Generate table of results."""
355
+ from pytablewriter import LatexTableWriter, MarkdownTableWriter
356
+
357
+ if column == "results":
358
+ column_name = "Tasks"
359
+ elif column == "groups":
360
+ column_name = "Groups"
361
+
362
+ all_headers = [
363
+ column_name,
364
+ "Version",
365
+ "Filter",
366
+ "n-shot",
367
+ "Metric",
368
+ "",
369
+ "Value",
370
+ "",
371
+ "Stderr",
372
+ ]
373
+
374
+ md_writer = MarkdownTableWriter()
375
+ latex_writer = LatexTableWriter()
376
+ md_writer.headers = all_headers
377
+ latex_writer.headers = all_headers
378
+
379
+ values = []
380
+
381
+ keys = result_dict[column].keys()
382
+ if sort_results:
383
+ # sort entries alphabetically by task or group name.
384
+ # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
385
+ # sorting here would mess that up
386
+ keys = sorted(keys)
387
+ for k in keys:
388
+ dic = result_dict[column][k]
389
+ version = result_dict["versions"].get(k, " N/A")
390
+ n = str(result_dict.get("n-shot", " ").get(k, " "))
391
+ higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
392
+
393
+ if "alias" in dic:
394
+ k = dic.pop("alias")
395
+
396
+ metric_items = dic.items()
397
+ metric_items = sorted(metric_items)
398
+
399
+ for (mf), v in metric_items:
400
+ m, _, f = mf.partition(",")
401
+ if m.endswith("_stderr"):
402
+ continue
403
+
404
+ hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
405
+
406
+ v = "%.4f" % v if isinstance(v, float) else v
407
+
408
+ if m + "_stderr" + "," + f in dic:
409
+ se = dic[m + "_stderr" + "," + f]
410
+ se = " N/A" if se == "N/A" else "%.4f" % se
411
+ values.append([k, version, f, n, m, hib, v, "±", se])
412
+ else:
413
+ values.append([k, version, f, n, m, hib, v, "", ""])
414
+ k = ""
415
+ version = ""
416
+ md_writer.value_matrix = values
417
+ latex_writer.value_matrix = values
418
+
419
+ # todo: make latex table look good
420
+ # print(latex_writer.dumps())
421
+
422
+ return md_writer.dumps()
423
+
424
+
425
+ def positional_deprecated(fn):
426
+ """
427
+ A decorator to nudge users into passing only keyword args (`kwargs`) to the
428
+ wrapped function, `fn`.
429
+ """
430
+
431
+ @functools.wraps(fn)
432
+ def _wrapper(*args, **kwargs):
433
+ if len(args) != 1 if inspect.ismethod(fn) else 0:
434
+ print(
435
+ f"WARNING: using {fn.__name__} with positional arguments is "
436
+ "deprecated and will be disallowed in a future version of "
437
+ "lm-evaluation-harness!"
438
+ )
439
+ return fn(*args, **kwargs)
440
+
441
+ return _wrapper
442
+
443
+
444
+ def ignore_constructor(loader, node):
445
+ return node
446
+
447
+
448
+ def import_function(loader: yaml.Loader, node, yaml_path: Path):
449
+ function_name = loader.construct_scalar(node)
450
+
451
+ *module_name, function_name = function_name.split(".")
452
+ if isinstance(module_name, list):
453
+ module_name = ".".join(module_name)
454
+ module_path = yaml_path.parent / f"{module_name}.py"
455
+
456
+ spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
457
+
458
+ if spec is None:
459
+ raise ImportError(f"Could not import module {module_name} from {module_path}.")
460
+ module = importlib.util.module_from_spec(spec)
461
+
462
+ if spec.loader is None:
463
+ raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
464
+ spec.loader.exec_module(module)
465
+
466
+ function = getattr(module, function_name)
467
+ return function
468
+
469
+
470
+ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
471
+ if mode == "simple":
472
+ constructor_fn = ignore_constructor
473
+ elif mode == "full":
474
+ if yaml_path is None:
475
+ raise ValueError("yaml_path must be provided if mode is 'full'.")
476
+ # Attach yaml_path to the import function so that it can be used later
477
+ constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
478
+
479
+ loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
480
+ # Add the import_function constructor to the YAML loader
481
+ yaml.add_constructor("!function", constructor_fn, Loader=loader)
482
+ if yaml_config is None:
483
+ with open(yaml_path, "rb") as file:
484
+ yaml_config = yaml.load(file, Loader=loader)
485
+
486
+ if yaml_dir is None:
487
+ yaml_dir = os.path.dirname(yaml_path)
488
+
489
+ assert yaml_dir is not None
490
+
491
+ if "include" in yaml_config:
492
+ include_path = yaml_config["include"]
493
+ del yaml_config["include"]
494
+
495
+ if isinstance(include_path, str):
496
+ include_path = [include_path]
497
+
498
+ # Load from the last one first
499
+ include_path.reverse()
500
+ final_yaml_config = {}
501
+ for path in include_path:
502
+ # Assumes that path is a full path.
503
+ # If not found, assume the included yaml
504
+ # is in the same dir as the original yaml
505
+ if not os.path.isfile(path):
506
+ path = os.path.join(yaml_dir, path)
507
+
508
+ try:
509
+ included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
510
+ final_yaml_config.update(included_yaml_config)
511
+ except Exception as ex:
512
+ # If failed to load, ignore
513
+ raise ex
514
+
515
+ final_yaml_config.update(yaml_config)
516
+ return final_yaml_config
517
+ return yaml_config
518
+
519
+
520
+ def regex_replace(string, pattern, repl, count: int = 0):
521
+ """Implements the `re.sub` function as a custom Jinja filter."""
522
+ return re.sub(pattern, repl, string, count=count)
523
+
524
+
525
+ env = Environment(
526
+ loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
527
+ )
528
+ env.filters["regex_replace"] = regex_replace
529
+
530
+
531
+ def apply_template(template: str, doc: dict) -> str:
532
+ rtemplate = env.from_string(template)
533
+ return rtemplate.render(**doc)
534
+
535
+
536
+ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
537
+ """
538
+ Method for creating a (potentially) sliced and limited
539
+ iterator from a raw document iterator. Used for splitting data
540
+ among ranks in multigpu setting or only pulling a sample of documents
541
+ """
542
+ return islice(raw_iterator, rank, limit, world_size)
543
+
544
+
545
+ def weighted_f1_score(items):
546
+ from sklearn.metrics import f1_score
547
+
548
+ unzipped_list = list(zip(*items))
549
+ golds = unzipped_list[0]
550
+ preds = unzipped_list[1]
551
+ fscore = f1_score(golds, preds, average="weighted")
552
+ return fscore
Prism/Dream/Dream_Baseline/eval_instruct/pyproject.toml ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=40.8.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "lm_eval"
7
+ version = "0.4.8"
8
+ authors = [
9
+ {name="EleutherAI", email="contact@eleuther.ai"}
10
+ ]
11
+ description = "A framework for evaluating language models"
12
+ readme = "README.md"
13
+ classifiers = [
14
+ "Development Status :: 3 - Alpha",
15
+ "Programming Language :: Python :: 3",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Operating System :: OS Independent",
18
+ ]
19
+ requires-python = ">=3.9"
20
+ license = { "text" = "MIT" }
21
+ dependencies = [
22
+ "accelerate>=0.26.0",
23
+ "evaluate",
24
+ "datasets>=2.16.0",
25
+ "evaluate>=0.4.0",
26
+ "jsonlines",
27
+ "numexpr",
28
+ "peft>=0.2.0",
29
+ "pybind11>=2.6.2",
30
+ "pytablewriter",
31
+ "rouge-score>=0.0.4",
32
+ "sacrebleu>=1.5.0",
33
+ "scikit-learn>=0.24.1",
34
+ "sqlitedict",
35
+ "torch>=1.8",
36
+ "tqdm-multiprocess",
37
+ "transformers>=4.1",
38
+ "zstandard",
39
+ "dill",
40
+ "word2number",
41
+ "more_itertools",
42
+ ]
43
+
44
+ [tool.setuptools.packages.find]
45
+ include = ["lm_eval*"]
46
+
47
+ # required to include yaml files in pip installation
48
+ [tool.setuptools.package-data]
49
+ lm_eval = ["**/*.yaml", "tasks/**/*"]
50
+
51
+ [project.scripts]
52
+ lm-eval = "lm_eval.__main__:cli_evaluate"
53
+ lm_eval = "lm_eval.__main__:cli_evaluate"
54
+
55
+ [project.urls]
56
+ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
57
+ Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
58
+
59
+ [project.optional-dependencies]
60
+ api = ["requests", "aiohttp", "tenacity", "tqdm", "tiktoken"]
61
+ audiolm_qwen = ["librosa", "soundfile"]
62
+ deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
63
+ dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy", "unitxt"]
64
+ gptq = ["auto-gptq[triton]>=0.6.0"]
65
+ gptqmodel = ["gptqmodel>=1.0.9"]
66
+ hf_transfer = ["hf_transfer"]
67
+ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
68
+ ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
69
+ ipex = ["optimum"]
70
+ japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
71
+ longbench=["jeiba", "fuzzywuzzy", "rouge"]
72
+ mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
73
+ math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
74
+ multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
75
+ neuronx = ["optimum[neuronx]"]
76
+ optimum = ["optimum[openvino]"]
77
+ promptsource = ["promptsource>=0.2.3"]
78
+ ruler = ["nltk", "wonderwords", "scipy"]
79
+ sae_lens = ["sae_lens"]
80
+ sentencepiece = ["sentencepiece>=0.1.98"]
81
+ sparseml = ["sparseml-nightly[llm]>=1.8.0.20240404"]
82
+ sparsify = ["sparsify"]
83
+ testing = ["pytest", "pytest-cov", "pytest-xdist"]
84
+ vllm = ["vllm>=0.4.2"]
85
+ wandb = ["wandb>=0.16.3", "pandas", "numpy"]
86
+ zeno = ["pandas", "zeno-client"]
87
+ all = [
88
+ "lm_eval[api]",
89
+ "lm_eval[audiolm_qwen]",
90
+ "lm_eval[deepsparse]",
91
+ "lm_eval[dev]",
92
+ "lm_eval[gptq]",
93
+ "lm_eval[gptqmodel]",
94
+ "lm_eval[hf_transfer]",
95
+ "lm_eval[ibm_watsonx_ai]",
96
+ "lm_eval[ifeval]",
97
+ "lm_eval[ipex]",
98
+ "lm_eval[japanese_leaderboard]",
99
+ "lm_eval[longbench]",
100
+ "lm_eval[mamba]",
101
+ "lm_eval[math]",
102
+ "lm_eval[multilingual]",
103
+ "lm_eval[neuronx]",
104
+ "lm_eval[optimum]",
105
+ "lm_eval[promptsource]",
106
+ "lm_eval[ruler]",
107
+ "lm_eval[sae_lens]",
108
+ "lm_eval[sentencepiece]",
109
+ "lm_eval[sparseml]",
110
+ "lm_eval[sparsify]",
111
+ "lm_eval[testing]",
112
+ "lm_eval[vllm]",
113
+ "lm_eval[wandb]",
114
+ "lm_eval[zeno]",
115
+ ]
116
+
117
+ [tool.pymarkdown]
118
+ plugins.md013.enabled = false # line-length
119
+ plugins.md024.allow_different_nesting = true # no-duplicate-headers
120
+ plugins.md025.enabled = false # single-header
121
+ plugins.md028.enabled = false # no-blanks-blockquote
122
+ plugins.md029.allow_extended_start_values = true # ol-prefix
123
+ plugins.md034.enabled = false # no-bare-urls
124
+
125
+ [tool.ruff.lint]
126
+ extend-select = ["I"]
127
+
128
+ [tool.ruff.lint.isort]
129
+ lines-after-imports = 2
130
+ known-first-party = ["lm_eval"]
131
+
132
+ [tool.ruff.lint.extend-per-file-ignores]
133
+ "__init__.py" = ["F401","F402","F403"]
134
+ "utils.py" = ["F401"]