diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7c148d9fbec8be41fd89a01aa8590deabd2c4cad --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/README.md @@ -0,0 +1,59 @@ +# GPQA + +### Paper + +Title: GPQA: A Graduate-Level Google-Proof Q&A Benchmark + +Abstract: https://arxiv.org/abs/2311.12022 + +We present GPQA, a challenging dataset of 448 multiple-choice questions written by domain experts in biology, physics, and chemistry. We ensure that the questions are high-quality and extremely difficult: experts who have or are pursuing PhDs in the corresponding domains reach 65% accuracy (74% when discounting clear mistakes the experts identified in retrospect), while highly skilled non-expert validators only reach 34% accuracy, despite spending on average over 30 minutes with unrestricted access to the web (i.e., the questions are “Google-proof”). The questions are also difficult for state-of-the-art AI systems, with our strongest GPT-4–based baseline achieving 39% accuracy. If we are to use future AI systems to help us answer very hard questions—for example, when developing new scientific knowledge—we need to develop *scalable oversight* methods that enable humans to supervise their outputs, which may be difficult even if the supervisors are themselves skilled and knowledgeable. The difficulty of GPQA both for skilled non-experts and frontier AI systems should enable realistic scalable oversight experiments, which we hope can help devise ways for human experts to reliably get truthful information from AI systems that surpass human capabilities. + +Homepage: `https://github.com/idavidrein/gpqa/tree/main` + +### Citation + +``` +@misc{rein2023gpqa, + title={GPQA: A Graduate-Level Google-Proof Q&A Benchmark}, + author={David Rein and Betty Li Hou and Asa Cooper Stickland and Jackson Petty and Richard Yuanzhe Pang and Julien Dirani and Julian Michael and Samuel R. Bowman}, + year={2023}, + eprint={2311.12022}, + archivePrefix={arXiv}, + primaryClass={cs.AI} +} +``` + +This dataset is gated, so you will have to accept the terms of use at https://huggingface.co/datasets/Idavidrein/gpqa and login via `huggingface-cli login` using your HF Hub token before running this task. + +### Groups, Tags, and Tasks + +#### Groups + +None + +#### Tags + +* `gpqa`: runs all GPQA variants. + +#### Tasks + +* `gpqa_{main, diamond, extended}_zeroshot` +* `gpqa_{main, diamond, extended}_n_shot` +* `gpqa_{main, diamond, extended}_generative_n_shot` +* `gpqa_{main, diamond, extended}_cot_zeroshot` +* `gpqa_{main, diamond, extended}_cot_n_shot` + +### Checklist + +For adding novel benchmarks/datasets to the library: + +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: + +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/_generate_configs.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/_generate_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..73ccb876a449a1e8eda5984d977194f6b0c064d9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/_generate_configs.py @@ -0,0 +1,26 @@ +import yaml +from tqdm import tqdm + + +def main() -> None: + subset = ["extended", "diamond", "main"] + setting = "cot_n_shot" + for task in tqdm(subset): + file_name = f"gpqa_{task}_{setting}.yaml" + try: + with open(f"{file_name}", "w") as f: + f.write("# Generated by _generate_configs.py\n") + yaml.dump( + { + "include": f"_gpqa_{setting}_yaml", + "task": f"gpqa_{task}_{setting}", + "dataset_name": f"gpqa_{task}", + }, + f, + ) + except FileExistsError: + pass + + +if __name__ == "__main__": + main() diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/_gpqa_cot_n_shot_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/_gpqa_cot_n_shot_yaml new file mode 100644 index 0000000000000000000000000000000000000000..97c0603bcc94f0c689269ea9859b62bdfab7644e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/_gpqa_cot_n_shot_yaml @@ -0,0 +1,38 @@ +dataset_path: Idavidrein/gpqa +tag: gpqa +output_type: generate_until +process_docs: !function utils.process_docs +training_split: train +# Because huggingface dataset only has train split +validation_split: train +test_split: null +description: "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n" +doc_to_text: "Question: {{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nLet's think step by step: " +doc_to_target: answer +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "(?<=The answer is )(.*)(?=.)" + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "multi_choice_regex" + group_select: -1 + ignore_case: true + ignore_punctuation: true + regex_pattern: "(\\([A-Z]\\))" + - function: "take_first" +generation_kwargs: + until: + - "" + do_sample: false + temperature: 0.0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_diamond_cot_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_diamond_cot_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..24e5f4f90f1f770f9f792e4aeef51e08d3aa08d9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_diamond_cot_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_diamond +include: _gpqa_cot_n_shot_yaml +task: gpqa_diamond_cot_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_extended_cot_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_extended_cot_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..002ede9a82110e3679bf3e1e958ded4342e408e3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_extended_cot_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_extended +include: _gpqa_cot_n_shot_yaml +task: gpqa_extended_cot_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_main_cot_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_main_cot_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..916b6ea06a2e22042344b668191adbb3c91c4e75 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/gpqa_main_cot_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_main +include: _gpqa_cot_n_shot_yaml +task: gpqa_main_cot_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96bcd52b140fd0a5896f55c0a52ea2fd5453fd53 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_n_shot/utils.py @@ -0,0 +1,39 @@ +import random +import re + +import datasets + + +def preprocess(text): + if text is None: + return " " + text = text.strip() + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + choices = [ + preprocess(doc["Incorrect Answer 1"]), + preprocess(doc["Incorrect Answer 2"]), + preprocess(doc["Incorrect Answer 3"]), + preprocess(doc["Correct Answer"]), + ] + + random.shuffle(choices) + correct_answer_index = choices.index(preprocess(doc["Correct Answer"])) + + out_doc = { + "choice1": choices[0], + "choice2": choices[1], + "choice3": choices[2], + "choice4": choices[3], + "choices": [choices[0], choices[1], choices[2], choices[3]], + "answer": f"({chr(65 + correct_answer_index)})", + } + return out_doc + + return dataset.map(_process_doc) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/_generate_configs.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/_generate_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bda00784cc2fa26b5f0d488cf7b6aea37243353d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/_generate_configs.py @@ -0,0 +1,26 @@ +import yaml +from tqdm import tqdm + + +def main() -> None: + subset = ["extended", "diamond", "main"] + setting = "cot_zeroshot" + for task in tqdm(subset): + file_name = f"gpqa_{task}_{setting}.yaml" + try: + with open(f"{file_name}", "w") as f: + f.write("# Generated by _generate_configs.py\n") + yaml.dump( + { + "include": f"_gpqa_{setting}_yaml", + "task": f"gpqa_{task}_{setting}", + "dataset_name": f"gpqa_{task}", + }, + f, + ) + except FileExistsError: + pass + + +if __name__ == "__main__": + main() diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c487a8c4a3e3806bfa265fa7dc7a3f897ddedff --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/_gpqa_cot_zeroshot_yaml @@ -0,0 +1,38 @@ +dataset_path: Idavidrein/gpqa +tag: gpqa +output_type: generate_until +process_docs: !function utils.process_docs +training_split: train +# Because huggingface dataset only has train split +validation_split: train +test_split: null +doc_to_text: "What is the correct answer to this question:{{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nLet's think step by step: " +doc_to_target: answer +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "(?<=The answer is )(.*)(?=.)" + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "multi_choice_regex" + group_select: -1 + ignore_case: true + ignore_punctuation: true + regex_pattern: "(\\([A-Z]\\))" + - function: "take_first" +generation_kwargs: + until: + - "" + do_sample: false + temperature: 0.0 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_diamond_cot_zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_diamond_cot_zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6a840fa1815096f5fa180ed06223e3523a06214 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_diamond_cot_zeroshot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_diamond +include: _gpqa_cot_zeroshot_yaml +task: gpqa_diamond_cot_zeroshot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_extended_cot_zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_extended_cot_zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f542a6148f231e2d7e7e2a5a3437047459e3856 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_extended_cot_zeroshot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_extended +include: _gpqa_cot_zeroshot_yaml +task: gpqa_extended_cot_zeroshot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_main_cot_zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_main_cot_zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c14604854294c4551e2602e573488c6a7fef254 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/gpqa_main_cot_zeroshot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_main +include: _gpqa_cot_zeroshot_yaml +task: gpqa_main_cot_zeroshot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96bcd52b140fd0a5896f55c0a52ea2fd5453fd53 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/cot_zeroshot/utils.py @@ -0,0 +1,39 @@ +import random +import re + +import datasets + + +def preprocess(text): + if text is None: + return " " + text = text.strip() + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + choices = [ + preprocess(doc["Incorrect Answer 1"]), + preprocess(doc["Incorrect Answer 2"]), + preprocess(doc["Incorrect Answer 3"]), + preprocess(doc["Correct Answer"]), + ] + + random.shuffle(choices) + correct_answer_index = choices.index(preprocess(doc["Correct Answer"])) + + out_doc = { + "choice1": choices[0], + "choice2": choices[1], + "choice3": choices[2], + "choice4": choices[3], + "choices": [choices[0], choices[1], choices[2], choices[3]], + "answer": f"({chr(65 + correct_answer_index)})", + } + return out_doc + + return dataset.map(_process_doc) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/_generate_configs.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/_generate_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c011ea02d25ca1d3550210f4a4644c97fa52c2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/_generate_configs.py @@ -0,0 +1,26 @@ +import yaml +from tqdm import tqdm + + +def main() -> None: + subset = ["extended", "diamond", "main"] + setting = "generative_n_shot" + for task in tqdm(subset): + file_name = f"gpqa_{task}_{setting}.yaml" + try: + with open(f"{file_name}", "w") as f: + f.write("# Generated by _generate_configs.py\n") + yaml.dump( + { + "include": f"_gpqa_{setting}_yaml", + "task": f"gpqa_{task}_{setting}", + "dataset_name": f"gpqa_{task}", + }, + f, + ) + except FileExistsError: + pass + + +if __name__ == "__main__": + main() diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/_gpqa_generative_n_shot_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/_gpqa_generative_n_shot_yaml new file mode 100644 index 0000000000000000000000000000000000000000..f43a9a414cb4e53e7d5e83787ae6c1e5de109111 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/_gpqa_generative_n_shot_yaml @@ -0,0 +1,39 @@ +dataset_path: Idavidrein/gpqa +tag: gpqa +output_type: generate_until +process_docs: !function utils.process_docs +training_split: train +# Because huggingface dataset only has train split +validation_split: train +test_split: null +description: "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n" +doc_to_text: "Question: {{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nAnswer:" +doc_to_target: answer +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "(?<=The answer is )(.*)(?=.)" + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "multi_choice_regex" + group_select: -1 + ignore_case: true + ignore_punctuation: true + regex_pattern: "(\\([A-Z]\\))" + - function: "take_first" +generation_kwargs: + until: + - "" + - "Question:" + - "<|im_end|>" + temperature: 0.0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_diamond_generative_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_diamond_generative_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a42094e8ba8ef6037820255b74a8830d550b8a9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_diamond_generative_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_diamond +include: _gpqa_generative_n_shot_yaml +task: gpqa_diamond_generative_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_extended_generative_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_extended_generative_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc40c2d97684c50b3992f5adf894ebe0c138b4ae --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_extended_generative_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_extended +include: _gpqa_generative_n_shot_yaml +task: gpqa_extended_generative_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_main_generative_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_main_generative_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..865f3cb5efa3d4b8641843cfde7db3c95bd8b8b3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/gpqa_main_generative_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_main +include: _gpqa_generative_n_shot_yaml +task: gpqa_main_generative_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96bcd52b140fd0a5896f55c0a52ea2fd5453fd53 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/generative/utils.py @@ -0,0 +1,39 @@ +import random +import re + +import datasets + + +def preprocess(text): + if text is None: + return " " + text = text.strip() + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + choices = [ + preprocess(doc["Incorrect Answer 1"]), + preprocess(doc["Incorrect Answer 2"]), + preprocess(doc["Incorrect Answer 3"]), + preprocess(doc["Correct Answer"]), + ] + + random.shuffle(choices) + correct_answer_index = choices.index(preprocess(doc["Correct Answer"])) + + out_doc = { + "choice1": choices[0], + "choice2": choices[1], + "choice3": choices[2], + "choice4": choices[3], + "choices": [choices[0], choices[1], choices[2], choices[3]], + "answer": f"({chr(65 + correct_answer_index)})", + } + return out_doc + + return dataset.map(_process_doc) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/_generate_configs.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/_generate_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..c01f208e767cb813e6d2116caf74c3d0b2fccfb3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/_generate_configs.py @@ -0,0 +1,26 @@ +import yaml +from tqdm import tqdm + + +def main() -> None: + subset = ["extended", "diamond", "main"] + + for task in tqdm(subset): + file_name = f"gpqa_{task}_n_shot.yaml" + try: + with open(f"{file_name}", "w") as f: + f.write("# Generated by _generate_configs.py\n") + yaml.dump( + { + "include": "_gpqa_n_shot_yaml", + "task": f"gpqa_{task}_n_shot", + "dataset_name": f"gpqa_{task}", + }, + f, + ) + except FileExistsError: + pass + + +if __name__ == "__main__": + main() diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/_gpqa_n_shot_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/_gpqa_n_shot_yaml new file mode 100644 index 0000000000000000000000000000000000000000..8406f8aabfa9d10eec18ef7a8565b6393a0bfc03 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/_gpqa_n_shot_yaml @@ -0,0 +1,21 @@ +dataset_path: Idavidrein/gpqa +tag: gpqa +output_type: multiple_choice +process_docs: !function utils.process_docs +training_split: train +# Because huggingface dataset only has train split +validation_split: train +test_split: null +description: "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n" +doc_to_text: "Question: {{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nAnswer:" +doc_to_target: answer +doc_to_choice: ["(A)", "(B)", "(C)", "(D)"] +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_diamond_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_diamond_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3043a7e53647ff72d535abc113dfccebaa1bd43c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_diamond_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_diamond +include: _gpqa_n_shot_yaml +task: gpqa_diamond_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_extended_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_extended_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d16b505b355bccb3d6fd70eb16b307c12d06a09 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_extended_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_extended +include: _gpqa_n_shot_yaml +task: gpqa_extended_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_main_n_shot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_main_n_shot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e5f3e9532ab41c0158409e6afb47393806c4177 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/gpqa_main_n_shot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_main +include: _gpqa_n_shot_yaml +task: gpqa_main_n_shot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b886d2879216094214ce534438e4db0c5e60f8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/n_shot/utils.py @@ -0,0 +1,41 @@ +import random +import re + +import datasets + + +def preprocess(text): + if text is None: + return " " + text = text.strip() + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +rng = random.Random(42) + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + choices = [ + preprocess(doc["Incorrect Answer 1"]), + preprocess(doc["Incorrect Answer 2"]), + preprocess(doc["Incorrect Answer 3"]), + preprocess(doc["Correct Answer"]), + ] + + rng.shuffle(choices) + correct_answer_index = choices.index(preprocess(doc["Correct Answer"])) + + out_doc = { + "choice1": choices[0], + "choice2": choices[1], + "choice3": choices[2], + "choice4": choices[3], + "answer": f"({chr(65 + correct_answer_index)})", + } + return out_doc + + return dataset.map(_process_doc) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..79afbd6f1d8d4b2eb54455d734f6245357580bd3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/_generate_configs.py @@ -0,0 +1,26 @@ +import yaml +from tqdm import tqdm + + +def main() -> None: + subset = ["extended", "diamond", "main"] + setting = "zeroshot" + for task in tqdm(subset): + file_name = f"gpqa_{task}_{setting}.yaml" + try: + with open(f"{file_name}", "w") as f: + f.write("# Generated by _generate_configs.py\n") + yaml.dump( + { + "include": f"_gpqa_{setting}_yaml", + "task": f"gpqa_{task}_{setting}", + "dataset_name": f"gpqa_{task}", + }, + f, + ) + except FileExistsError: + pass + + +if __name__ == "__main__": + main() diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/_gpqa_zeroshot_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/_gpqa_zeroshot_yaml new file mode 100644 index 0000000000000000000000000000000000000000..500f1921bec3db0d1282b8501b7a0841ebbb79c4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/_gpqa_zeroshot_yaml @@ -0,0 +1,21 @@ +dataset_path: Idavidrein/gpqa +tag: gpqa +output_type: multiple_choice +process_docs: !function utils.process_docs +training_split: train +# Because huggingface dataset only has train split +validation_split: train +test_split: null +doc_to_text: "What is the correct answer to this question:{{Question}}\nChoices:\n(A) {{choice1}}\n(B) {{choice2}}\n(C) {{choice3}}\n(D) {{choice4}}\nAnswer:" +doc_to_target: answer +doc_to_choice: ["(A)", "(B)", "(C)", "(D)"] +num_fewshot: 0 +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_diamond_zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_diamond_zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3a7921c30b3ff09e82aacb4c0e915010f698966 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_diamond_zeroshot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_diamond +include: _gpqa_zeroshot_yaml +task: gpqa_diamond_zeroshot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_extended_zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_extended_zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e7347f11154351ad4560200a3f3bf54106a1a8f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_extended_zeroshot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_extended +include: _gpqa_zeroshot_yaml +task: gpqa_extended_zeroshot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_main_zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_main_zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a8d7fb59025d148130f2a468cb1bbdfad959102 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/gpqa_main_zeroshot.yaml @@ -0,0 +1,4 @@ +# Generated by _generate_configs.py +dataset_name: gpqa_main +include: _gpqa_zeroshot_yaml +task: gpqa_main_zeroshot diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c2317e02efd132aea27ec8c8fad284df55ccd382 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gpqa/zeroshot/utils.py @@ -0,0 +1,38 @@ +import random +import re + +import datasets + + +def preprocess(text): + if text is None: + return " " + text = text.strip() + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc): + choices = [ + preprocess(doc["Incorrect Answer 1"]), + preprocess(doc["Incorrect Answer 2"]), + preprocess(doc["Incorrect Answer 3"]), + preprocess(doc["Correct Answer"]), + ] + + random.shuffle(choices) + correct_answer_index = choices.index(preprocess(doc["Correct Answer"])) + + out_doc = { + "choice1": choices[0], + "choice2": choices[1], + "choice3": choices[2], + "choice4": choices[3], + "answer": f"({chr(65 + correct_answer_index)})", + } + return out_doc + + return dataset.map(_process_doc) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1556151f821f526cf57388f15bb5c867af904a15 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/README.md @@ -0,0 +1,62 @@ +# GSM8k + +## Paper +Training Verifiers to Solve Math Word Problems +https://arxiv.org/abs/2110.14168 + +State-of-the-art language models can match human performance on many tasks, but +they still struggle to robustly perform multi-step mathematical reasoning. To +diagnose the failures of current models and support research, we introduce GSM8K, +a dataset of 8.5K high quality linguistically diverse grade school math word problems. +We find that even the largest transformer models fail to achieve high test performance, +despite the conceptual simplicity of this problem distribution. + +NOTE: See the official implementation of the task: + https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py +for how to make use of the dataset's calculator annotations in your language +model's sample/generation function. + +Homepage: https://github.com/openai/grade-school-math + + +## Citation +``` +@misc{cobbe2021training, + title={Training Verifiers to Solve Math Word Problems}, + author={Karl Cobbe and Vineet Kosaraju and Mohammad Bavarian and Jacob Hilton and Reiichiro Nakano and Christopher Hesse and John Schulman}, + year={2021}, + eprint={2110.14168}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +### Groups and Tasks + +#### Groups + +- `math_word_problems` +- `chain_of_thought` +- `self_consistency` + +#### Tasks + +- `gsm8k_yaml` +- `gsm8k_cot`: GSM8K with Chain-of-Thought +- `gsm8k_cot_self_consistency`: GSM8K with Chain-of-Thought and Self-Consistency +- `gsm8k_cot_llama`: GSM8K with prompt formatting modified to conform to the evaluation settings described by Meta here: https://huggingface.co/datasets/meta-llama/Meta-Llama-3.1-8B-Instruct-evals/viewer/Meta-Llama-3.1-8B-Instruct-evals__gsm8k__details?row=0 + - Use this task with --fewshot_as_multiturn and --apply_chat_template to replicate Meta's reported performance. + + +### Checklist + +- [x] Is in Eval-harness v1.0 ? +- [ ] Has been checked for regression from v1.0? +- [ ] Has been checked for equivalence with original paper methodology? +- [ ] "Main" checked variant clearly denoted? + +### Variant Wishlist + +- [ ] Variant with Calculator (see https://github.com/openai/grade-school-math/blob/master/grade_school_math/calculator.py for example implementation) +- [ ] Using Verifiers +- [ ] Majority voting "without CoT" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e7948eeb8e3e7039f0c9c1738ac89aa19f4c4bb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml @@ -0,0 +1,84 @@ +dataset_name: main +dataset_path: gsm8k +doc_to_target: '{{answer.split(''####'')[-1].strip() if answer is defined else target}}' +doc_to_text: "Given the following problem, reason and give a final answer to the problem.\nProblem: {{question}}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n" +fewshot_config: + sampler: first_n + samples: + - question: There are 15 trees in the grove. Grove workers will plant trees in the + grove today. After they are done, there will be 21 trees. How many trees did + the grove workers plant today? + target: There are 15 trees originally. Then there were 21 trees after some more + were planted. So there must have been 21 - 15 = 6. The final answer is 6 + - question: If there are 3 cars in the parking lot and 2 more cars arrive, how many + cars are in the parking lot? + target: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer + is 5 + - question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many + pieces do they have left in total? + target: Originally, Leah had 32 chocolates. Her sister had 42. So in total they + had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The final answer is 39 + - question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 + lollipops. How many lollipops did Jason give to Denny? + target: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. + So he gave Denny 20 - 12 = 8. The final answer is 8 + - question: Shawn has five toys. For Christmas, he got two toys each from his mom and + dad. How many toys does he have now? + target: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, + then that is 4 more toys. 5 + 4 = 9. The final answer is 9 + - question: There were nine computers in the server room. Five more computers were + installed each day, from monday to thursday. How many computers are now in the + server room? + target: There were originally 9 computers. For each of 4 days, 5 more computers + were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The final answer is + 29 + - question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, + he lost 2 more. How many golf balls did he have at the end of wednesday? + target: Michael started with 58 golf balls. After losing 23 on tuesday, he had + 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The final answer + is 33 + - question: Olivia has $23. She bought five bagels for $3 each. How much money does + she have left? + target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 + dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8 +filter_list: +- filter: + - function: regex + group_select: -1 + regex_pattern: The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+)) + - function: take_first + name: strict-match +- filter: + - function: regex + group_select: -1 + regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+) + - function: take_first + name: flexible-extract +generation_kwargs: + do_sample: false + until: + - '<|eot_id|>' + - '<|start_header_id|>user<|end_header_id|>' + - 'Q:' + - + - <|im_end|> +tag: +- chain_of_thought +metadata: + version: 3.0 +metric_list: +- aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: false + metric: exact_match + regexes_to_ignore: + - ',' + - \$ + - '(?s).*#### ' + - \.$ +num_fewshot: 8 +output_type: generate_until +repeats: 1 +task: gsm8k_cot_llama +test_split: test diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0994081b049c0815ae85b9539b627e4c8df00dd3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-self-consistency.yaml @@ -0,0 +1,34 @@ +include: gsm8k-cot.yaml +tag: + - chain_of_thought + - self_consistency +task: gsm8k_cot_self_consistency +generation_kwargs: + until: + - "Q:" + - "\n\n" + do_sample: true + temperature: 0.2 +repeats: 64 +filter_list: + - name: "score-first" # pick only the first response, and report metrics on that + filter: + - function: "regex" + regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)" + - function: "take_first" + - name: "maj@64" + filter: + - function: "regex" + regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)" + - function: "majority_vote" + - function: "take_first" + - name: "maj@8" # get Maj@8 , via selecting the first 8 responses. Using a better estimator would be optimal. + filter: + - function: "take_first_k" + k: 8 + - function: "regex" + regex_pattern: "The answer is (\\-?[0-9\\.\\,]*[0-9]+)" + - function: "majority_vote" + - function: "take_first" +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-zeroshot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-zeroshot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c112d324acf707e5934432068abd2ad6143438ac --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot-zeroshot.yaml @@ -0,0 +1,44 @@ +tag: + - math_word_problems +task: gsm8k_cot_zeroshot +dataset_path: gsm8k +dataset_name: main +output_type: generate_until +training_split: train +fewshot_split: train +test_split: test +doc_to_text: "Q: {{question}}\nA: Let's think step by step." +doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}" +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: false + regexes_to_ignore: + - "," + - "\\$" + - "(?s).*#### " + - "\\.$" +generation_kwargs: + until: + - "Q:" + - "" + - "<|im_end|>" + do_sample: false +repeats: 1 +num_fewshot: 0 +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "The answer is (\\-?[0-9\\.\\,]+)." + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "regex" + group_select: -1 + regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)" + - function: "take_first" +metadata: + version: 3.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d125b0198535122fd5b12a388e903b03ee5f6020 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k-cot.yaml @@ -0,0 +1,83 @@ +dataset_name: main +dataset_path: gsm8k +doc_to_target: '{{answer.split(''####'')[-1].strip() if answer is defined else target}}' +doc_to_text: 'Q: {{question}} + + A:' +fewshot_config: + sampler: first_n + samples: + - question: There are 15 trees in the grove. Grove workers will plant trees in the + grove today. After they are done, there will be 21 trees. How many trees did + the grove workers plant today? + target: There are 15 trees originally. Then there were 21 trees after some more + were planted. So there must have been 21 - 15 = 6. The answer is 6. + - question: If there are 3 cars in the parking lot and 2 more cars arrive, how many + cars are in the parking lot? + target: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer + is 5. + - question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many + pieces do they have left in total? + target: Originally, Leah had 32 chocolates. Her sister had 42. So in total they + had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39. + - question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 + lollipops. How many lollipops did Jason give to Denny? + target: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. + So he gave Denny 20 - 12 = 8. The answer is 8. + - question: Shawn has five toys. For Christmas, he got two toys each from his mom and + dad. How many toys does he have now? + target: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, + then that is 4 more toys. 5 + 4 = 9. The answer is 9. + - question: There were nine computers in the server room. Five more computers were + installed each day, from monday to thursday. How many computers are now in the + server room? + target: There were originally 9 computers. For each of 4 days, 5 more computers + were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is + 29. + - question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, + he lost 2 more. How many golf balls did he have at the end of wednesday? + target: Michael started with 58 golf balls. After losing 23 on tuesday, he had + 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer + is 33. + - question: Olivia has $23. She bought five bagels for $3 each. How much money does + she have left? + target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 + dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8. +filter_list: +- filter: + - function: regex + regex_pattern: The answer is (\-?[0-9\.\,]+). + - function: take_first + name: strict-match +- filter: + - function: regex + group_select: -1 + regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+) + - function: take_first + name: flexible-extract +generation_kwargs: + do_sample: false + until: + - 'Q:' + - + - <|im_end|> +tag: +- chain_of_thought +metadata: + version: 3.0 +metric_list: +- aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: false + metric: exact_match + regexes_to_ignore: + - ',' + - \$ + - '(?s).*#### ' + - \.$ +num_fewshot: 8 +output_type: generate_until +repeats: 1 +task: gsm8k_cot +test_split: test diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9d5bb39aedc0e2b991f0d79f2de6face47a31cf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/gsm8k/gsm8k.yaml @@ -0,0 +1,45 @@ +tag: + - math_word_problems +task: gsm8k +dataset_path: gsm8k +dataset_name: main +output_type: generate_until +training_split: train +fewshot_split: train +test_split: test +doc_to_text: "Question: {{question}}\nAnswer:" +doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}" +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: false + regexes_to_ignore: + - "," + - "\\$" + - "(?s).*#### " + - "\\.$" +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + do_sample: false + temperature: 0.0 +repeats: 1 +num_fewshot: 5 +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "#### (\\-?[0-9\\.\\,]+)" + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "regex" + group_select: -1 + regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)" + - function: "take_first" +metadata: + version: 3.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..44a2cb829233370020319f39e6ae7323e601aabb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/README.md @@ -0,0 +1,52 @@ +# HumanEval + +## Paper +Evaluating Large Language Models Trained on Code +https://arxiv.org/abs/2107.03374 + +We introduce Codex, a GPT language model fine-tuned on publicly available code from GitHub, and study its Python code-writing capabilities. A distinct production version of Codex powers GitHub Copilot. On HumanEval, a new evaluation set we release to measure functional correctness for synthesizing programs from docstrings, our model solves 28.8% of the problems, while GPT-3 solves 0% and GPT-J solves 11.4%. Furthermore, we find that repeated sampling from the model is a surprisingly effective strategy for producing working solutions to difficult prompts. Using this method, we solve 70.2% of our problems with 100 samples per problem. Careful investigation of our model reveals its limitations, including difficulty with docstrings describing long chains of operations and with binding operations to variables. Finally, we discuss the potential broader impacts of deploying powerful code generation technologies, covering safety, security, and economics. + +Homepage: https://github.com/openai/human-eval + +Note: For instruct tuned models, we recommend the instruct variant. That uses a gen_prefix to ensure the model completes the partial code snippet (might not work with all APIs) + +## Citation +``` +@article{chen2021codex, + title={Evaluating Large Language Models Trained on Code}, + author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba}, + year={2021}, + eprint={2107.03374}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` + +### Groups and Tasks + +#### Groups + +* Not part of a group yet. + +#### Tasks + +- `humaneval` pass@1 +- `humaneval_64` pass@64 variant +- `humaneval_instruct`: pass@1 with config more appropriate for instruct models. (implementation taken from llama [evals](https://huggingface.co/datasets/meta-llama/Llama-3.1-8B-Instruct-evals/viewer/Llama-3.1-8B-Instruct-evals__human_eval__details?row=0)) +- `humaneval_instruct_64`: pass@64 variant + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [ ] Is the task an existing benchmark in the literature? + * [ ] Have you referenced the original paper that introduced the task? + * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? + +### Changelog +v2 20-MAR-2025: `humaneval_instruct`, `humaneval_instruct_64`: fixed typo in gen_prefix diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b22f67b77ce67db8de66bb8ade37e173a89854b2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval.yaml @@ -0,0 +1,26 @@ +task: humaneval +dataset_path: openai/openai_humaneval +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "{{prompt}}" +doc_to_target: "{{test}}\ncheck({{entry_point}})" +metric_list: + - metric: !function utils.pass_at_k + aggregation: mean + higher_is_better: true + k: [1] +generation_kwargs: + until: + - "[DONE]" + max_gen_toks: 1024 + do_sample: false +repeats: 1 +num_fewshot: 0 +filter_list: + - name: "create_test" + filter: + - function: "custom" + filter_fn: !function utils.build_predictions +metadata: + version: 1.0 \ No newline at end of file diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5_instruct.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..953d10de29fa6044b7db2a160f162fd67c945b19 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_5_instruct.yaml @@ -0,0 +1,11 @@ +include: humaneval_5.yaml +task: humaneval_5_instruct +doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n```{{prompt}}" +gen_prefix: "Here is the completed function:\n```python\n{{prompt}}\n" +filter_list: + - name: "create_test" + filter: + - function: "custom" + filter_fn: !function utils.build_predictions_instruct +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_64_instruct.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_64_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca0f38c31e8d6b8d6b3ae8e7847fd6141f187492 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_64_instruct.yaml @@ -0,0 +1,11 @@ +include: humaneval_64.yaml +task: humaneval_64_instruct +doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n```{{prompt}}" +gen_prefix: "Here is the completed function:\n```python\n{{prompt}}\n" +filter_list: + - name: "create_test" + filter: + - function: "custom" + filter_fn: !function utils.build_predictions_instruct +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_plus.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_plus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e31772050e8ecc713827804cba3b81c3da3cfc2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/humaneval/humaneval_plus.yaml @@ -0,0 +1,3 @@ +include: humaneval.yaml +task: humaneval_plus +dataset_path: evalplus/humanevalplus diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aced2e78bb5adfd0ff413b4ee72c53e0fa3d5cc6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/README.md @@ -0,0 +1,45 @@ +# IFEval + +### Paper + +Title: Instruction-Following Evaluation for Large Language Models +Abstract: https://arxiv.org/abs/2311.07911 + +One core capability of Large Language Models (LLMs) is to follow natural language instructions. However, the evaluation of such abilities is not standardized: Human evaluations are expensive, slow, and not objectively reproducible, while LLM-based auto-evaluation is potentially biased or limited by the ability of the evaluator LLM. To overcome these issues, we introduce Instruction-Following Eval (IFEval) for large language models. IFEval is a straightforward and easy-to-reproduce evaluation benchmark. It focuses on a set of "verifiable instructions" such as "write in more than 400 words" and "mention the keyword of AI at least 3 times". We identified 25 types of those verifiable instructions and constructed around 500 prompts, with each prompt containing one or more verifiable instructions. We show evaluation results of two widely available LLMs on the market. Our code and data can be found at https://github.com/google-research/google-research/tree/master/instruction_following_eval + +Homepage: https://github.com/google-research/google-research/tree/master/instruction_following_eval + + +### Citation + +``` +@article{zhou2023instructionfollowing, + title={Instruction-Following Evaluation for Large Language Models}, + author={Jeffrey Zhou and Tianjian Lu and Swaroop Mishra and Siddhartha Brahma and Sujoy Basu and Yi Luan and Denny Zhou and Le Hou}, + journal={arXiv preprint arXiv:2311.07911}, + year={2023}, +} +``` + +### Groups and Tasks + +#### Groups + +* Not part of a group yet + +#### Tasks + +* `ifeval` + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/ifeval.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/ifeval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..508a63a9452874109cd949f5d5a5e00ad5f66b36 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/ifeval.yaml @@ -0,0 +1,29 @@ +task: ifeval +dataset_path: google/IFEval +dataset_name: null +output_type: generate_until +test_split: train +num_fewshot: 0 +doc_to_text: prompt +doc_to_target: 0 +generation_kwargs: + until: [] + do_sample: false + temperature: 0.0 + max_gen_toks: 1280 +process_results: !function utils.process_results +metric_list: + - metric: prompt_level_strict_acc + aggregation: mean + higher_is_better: true + - metric: inst_level_strict_acc + aggregation: !function utils.agg_inst_level_acc + higher_is_better: true + - metric: prompt_level_loose_acc + aggregation: mean + higher_is_better: true + - metric: inst_level_loose_acc + aggregation: !function utils.agg_inst_level_acc + higher_is_better: true +metadata: + version: 4.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7bcce13b0f29b829f21dea14b8f7ce5baeaac1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions.py @@ -0,0 +1,1612 @@ +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library of instructions.""" + +import collections +import json +import logging +import random +import re +import string +from typing import Dict, Optional, Sequence, Union + +import langdetect + +from lm_eval.tasks.ifeval import instructions_util + + +logger = logging.getLogger(__name__) + +_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]] + +_LANGUAGES = instructions_util.LANGUAGE_CODES + +# The relational operation for comparison. +_COMPARISON_RELATION = ("less than", "at least") + +# The maximum number of sentences. +_MAX_NUM_SENTENCES = 20 + +# The number of placeholders. +_NUM_PLACEHOLDERS = 4 + +# The number of bullet lists. +_NUM_BULLETS = 5 + +# The options of constrained response. +_CONSTRAINED_RESPONSE_OPTIONS = ( + "My answer is yes.", + "My answer is no.", + "My answer is maybe.", +) + +# The options of starter keywords. +_STARTER_OPTIONS = ( + "I would say", + "My answer is", + "I believe", + "In my opinion", + "I think", + "I reckon", + "I feel", + "From my perspective", + "As I see it", + "According to me", + "As far as I'm concerned", + "To my understanding", + "In my view", + "My take on it is", + "As per my perception", +) + +# The options of ending keywords. +# TODO(jeffreyzhou) add more ending options +_ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") + +# The number of highlighted sections. +_NUM_HIGHLIGHTED_SECTIONS = 4 + +# The section splitter. +_SECTION_SPLITER = ("Section", "SECTION") + +# The number of sections. +_NUM_SECTIONS = 5 + +# The number of paragraphs. +_NUM_PARAGRAPHS = 5 + +# The postscript marker. +_POSTSCRIPT_MARKER = ("P.S.", "P.P.S") + +# The number of keywords. +_NUM_KEYWORDS = 2 + +# The occurrences of a single keyword. +_KEYWORD_FREQUENCY = 3 + +# The occurrences of a single letter. +_LETTER_FREQUENCY = 10 + +# The occurrences of words with all capital letters. +_ALL_CAPITAL_WORD_FREQUENCY = 20 + +# The number of words in the response. +_NUM_WORDS_LOWER_LIMIT = 100 +_NUM_WORDS_UPPER_LIMIT = 500 + + +class Instruction: + """An instruction template.""" + + def __init__(self, instruction_id): + self.id = instruction_id + + def build_description(self, **kwargs): + raise NotImplementedError("`build_description` not implemented.") + + def get_instruction_args(self): + raise NotImplementedError("`get_instruction_args` not implemented.") + + def get_instruction_args_keys(self): + raise NotImplementedError("`get_instruction_args_keys` not implemented.") + + def check_following(self, value): + raise NotImplementedError("`check_following` not implemented.") + + +class ResponseLanguageChecker(Instruction): + """Check the language of the entire response.""" + + def build_description(self, *, language=None): + """Build the instruction description. + + Args: + language: A string representing the expected language of the response. The + language has to comply to the 97 types defined in + `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows + ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); + for example, `en` for English, `zh` for Chinese, `fr` for French. + + Returns: + A string representing the instruction description. + """ + self._language = language + if self._language is None: + self._language = random.choice(list(_LANGUAGES.keys())) + # TODO(tianjianlu): opens the description generation to more choices. + self._description_pattern = ( + "Your ENTIRE response should be in {language} language, no other " + + "language is allowed." + ) + return self._description_pattern.format(language=_LANGUAGES[self._language]) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"language": self._language} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["language"] + + def check_following(self, value): + """Check if the language of the entire response follows the instruction. + + Args: + value: A string representing the response. + + Returns: + True if the language of `value` follows instruction; otherwise False. + """ + assert isinstance(value, str) + + try: + return langdetect.detect(value) == self._language + except langdetect.LangDetectException as e: + # Count as instruction is followed. + logging.error( + "Unable to detect language for text %s due to %s", value, e + ) # refex: disable=pytotw.037 + return True + + +class NumberOfSentences(Instruction): + """Check the number of sentences.""" + + def build_description(self, *, num_sentences=None, relation=None): + """Build the instruction description. + + Args: + num_sentences: An integer specifying the number of sentences as a + threshold. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of sentences < the threshold; + if 'at least', the actual number of sentences >= the threshold. + + Returns: + A string representing the instruction description. + """ + # The number of sentences as a threshold for comparison. + self._num_sentences_threshold = num_sentences + if self._num_sentences_threshold is None or self._num_sentences_threshold < 0: + self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = ( + "Your response should contain {relation} {num_sentences} sentences." + ) + return self._description_pattern.format( + relation=self._comparison_relation, + num_sentences=self._num_sentences_threshold, + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return { + "num_sentences": self._num_sentences_threshold, + "relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_sentences", "relation"] + + def check_following(self, value): + """Check if the number of sentences follows the instruction. + + Args: + value: A string representing the response. + + Returns: + True if the response follows the instruction. + + Raise: + ValueError if the string in `instruction_args` is not in + [`less_than`, `at_least`]. + """ + num_sentences = instructions_util.count_sentences(value) + if self._comparison_relation == _COMPARISON_RELATION[0]: + return num_sentences < self._num_sentences_threshold + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return num_sentences >= self._num_sentences_threshold + + +class PlaceholderChecker(Instruction): + """Check the placeholders in template writing.""" + + def build_description(self, *, num_placeholders=None): + """Build the instruction description. + + Args: + num_placeholders: An integer denoting the minimum number of + placeholders required in the response. + + Returns: + A string representing the instruction description. + """ + self._num_placeholders = num_placeholders + if self._num_placeholders is None or self._num_placeholders < 0: + self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS) + self._description_pattern = ( + "The response must contain at least {num_placeholders} placeholders " + + "represented by square brackets, such as [address]." + ) + return self._description_pattern.format(num_placeholders=self._num_placeholders) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"num_placeholders": self._num_placeholders} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_placeholders"] + + def check_following(self, value): + """Check if the number of placeholders follows the instruction. + + Args: + value: A string representing the response. + + Returns: + True if the actual number of placeholders in the response is greater than + or equal to `num_placeholders`; otherwise, False. + """ + placeholders = re.findall(r"\[.*?\]", value) + num_placeholders = len(placeholders) + return num_placeholders >= self._num_placeholders + + +class BulletListChecker(Instruction): + """Checks the bullet list in the prompt.""" + + def build_description(self, *, num_bullets=None): + """Build the instruction description. + + Args: + num_bullets: An integer specifying the exact number of bullet lists + that is required to appear in the response. + + Returns: + A string representing the instruction description. + """ + self._num_bullets = num_bullets + if self._num_bullets is None or self._num_bullets < 0: + self._num_bullets = random.randint(1, _NUM_BULLETS) + self._description_pattern = ( + "Your answer must contain exactly {num_bullets} bullet points. " + + "Use the markdown bullet points such as:\n" + + "* This is point 1. \n" + + "* This is point 2" + ) + return self._description_pattern.format(num_bullets=self._num_bullets) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"num_bullets": self._num_bullets} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_bullets"] + + def check_following(self, value): + r"""Check if the number of bullet lists meets the requirement. + + Args: + value: A string representing the response. The response is expected to + contain some bullet lists that start with `\*`. + + Returns: + True if the actual number of bullet lists in the response meets the + requirement. + """ + bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) + bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) + num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) + return num_bullet_lists == self._num_bullets + + +class ConstrainedResponseChecker(Instruction): + """Checks the constrained response.""" + + def build_description(self): + """Build the instruction description.""" + # A sequence of string(s) representing the options of the expected response. + self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS + self._description_pattern = ( + "Answer with one of the following options: {response_options}" + ) + return self._description_pattern.format( + response_options=self._constrained_responses + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response matches the constrained options. + + Args: + value: A string representing the response. + + Returns: + True if the actual response contains one of the options in the constrained + responses; otherwise False. + """ + value = value.strip() + for constrained_response in self._constrained_responses: + if constrained_response in value: + return True + return False + + +class ConstrainedStartChecker(Instruction): + """Checks the response start.""" + + def build_description(self, *, starter=None): + """Build the instruction description. + + Args: + starter: A string representing the keyword that the response should start + with. + + Returns: + A string representing the instruction description. + """ + self._starter = starter.strip() if isinstance(starter, str) else starter + if self._starter is None: + self._starter = random.choice(_STARTER_OPTIONS) + self._description_pattern = ( + "During the conversation, when it is your turn, " + + "please always start with {starter}" + ) + return self._description_pattern.format(starter=self._starter) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"starter": self._starter} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["starter"] + + def check_following(self, value): + """Checks if the response starts with the constrained keyword or phrase. + + Args: + value: A string representing the response. + + Returns: + True if the response starts with the given phrase or keyword that is + contained in `instruction_args`; otherwise, False. + """ + response_pattern = r"^\s*" + self._starter + r".*$" + response_with_constrained_start = re.search( + response_pattern, value, flags=re.MULTILINE + ) + return True if response_with_constrained_start else False + + +class HighlightSectionChecker(Instruction): + """Checks the highlighted section.""" + + def build_description(self, *, num_highlights=None): + """Build the instruction description. + + Args: + num_highlights: An integer specifying the minimum number of highlighted + sections. + + Returns: + A string representing the instruction description. + """ + self._num_highlights = num_highlights + if self._num_highlights is None or self._num_highlights < 0: + self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS) + + self._description_pattern = ( + "Highlight at least {num_highlights} sections in your answer with " + + "markdown, i.e. *highlighted section*." + ) + + return self._description_pattern.format(num_highlights=self._num_highlights) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"num_highlights": self._num_highlights} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_highlights"] + + def check_following(self, value): + """Checks if the number of highlighted sections meets the requirement. + + Args: + value: a string representing the response. The response is expected to + contain highlighted sections in the format of *highlighted*. + + Returns: + True if the actual number of highlighted sections in the format of + *highlighted sections* meets the minimum requirement; otherwise False. + """ + num_highlights = 0 + highlights = re.findall(r"\*[^\n\*]*\*", value) + double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value) + for highlight in highlights: + if highlight.strip("*").strip(): + num_highlights += 1 + for highlight in double_highlights: + if highlight.removeprefix("**").removesuffix("**").strip(): + num_highlights += 1 + + return num_highlights >= self._num_highlights + + +class SectionChecker(Instruction): + """Checks the sections.""" + + def build_description(self, *, section_spliter=None, num_sections=None): + """Build the instruction description. + + Args: + section_spliter: A string represents the section spliter keyword that + marks a new section, i.e., `Section` or `SECTION`. + num_sections: An integer specifying the number of sections. + + Returns: + A string representing the instruction description. + """ + self._section_spliter = ( + section_spliter.strip() + if isinstance(section_spliter, str) + else section_spliter + ) + if self._section_spliter is None: + self._section_spliter = random.choice(_SECTION_SPLITER) + + self._num_sections = num_sections + if self._num_sections is None or self._num_sections < 0: + self._num_sections = random.randint(1, _NUM_SECTIONS) + + self._description_pattern = ( + "Your response must have {num_sections} sections. Mark the beginning " + + "of each section with {section_spliter} X, such as:\n" + + "{section_spliter} 1\n" + + "[content of section 1]\n" + + "{section_spliter} 2\n" + + "[content of section 2]" + ) + + return self._description_pattern.format( + num_sections=self._num_sections, section_spliter=self._section_spliter + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return { + "section_spliter": self._section_spliter, + "num_sections": self._num_sections, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["section_spliter", "num_sections"] + + def check_following(self, value): + """Checks the response contains multiple sections. + + Args: + value: A string representing the response. The response is expected + to contain multiple sections (number of sections is greater than 1). + A new section starts with `Section 1`, where the number denotes the + section index. + + Returns: + True if the number of sections in the response is greater than or equal to + the minimum number of sections; otherwise, False. + """ + section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" + sections = re.split(section_splitter_patten, value) + num_sections = len(sections) - 1 + return num_sections >= self._num_sections + + +class ParagraphChecker(Instruction): + """Checks the paragraphs.""" + + def build_description(self, *, num_paragraphs=None): + """Build the instruction description. + + Args: + num_paragraphs: An integer specifying the number of paragraphs. + + Returns: + A string representing the instruction description. + """ + self._num_paragraphs = num_paragraphs + if self._num_paragraphs is None or self._num_paragraphs < 0: + self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) + + self._description_pattern = ( + "There should be {num_paragraphs} paragraphs. " + + "Paragraphs are separated with the markdown divider: ***" + ) + + return self._description_pattern.format(num_paragraphs=self._num_paragraphs) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"num_paragraphs": self._num_paragraphs} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_paragraphs"] + + def check_following(self, value): + """Checks the response contains required number of paragraphs. + + Args: + value: A string representing the response. The response may contain + paragraphs that are separated by the markdown divider: `***`. + + Returns: + True if the actual number of paragraphs is the same as required; + otherwise, False. + """ + paragraphs = re.split(r"\s?\*\*\*\s?", value) + num_paragraphs = len(paragraphs) + + for index, paragraph in enumerate(paragraphs): + if not paragraph.strip(): + if index == 0 or index == len(paragraphs) - 1: + num_paragraphs -= 1 + else: + return False + + return num_paragraphs == self._num_paragraphs + + +class PostscriptChecker(Instruction): + """Checks the postscript.""" + + def build_description(self, *, postscript_marker=None): + """Build the instruction description. + + Args: + postscript_marker: A string containing the keyword that marks the start + of the postscript section. + + Returns: + A string representing the instruction description. + """ + self._postscript_marker = ( + postscript_marker.strip() + if isinstance(postscript_marker, str) + else postscript_marker + ) + if self._postscript_marker is None: + self._postscript_marker = random.choice(_POSTSCRIPT_MARKER) + + self._description_pattern = ( + "At the end of your response, please explicitly add a postscript " + + "starting with {postscript}" + ) + + return self._description_pattern.format(postscript=self._postscript_marker) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"postscript_marker": self._postscript_marker} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["postscript_marker"] + + def check_following(self, value): + """Checks if the response follows the postscript format. + + Args: + value: a string representing the response. The response is expected to + contain a postscript section. + + Returns: + True if the response contains a postscript section starting with + the keyword containing in the `instruction_args`; otherwise False. + """ + value = value.lower() + if self._postscript_marker == "P.P.S": + postscript_pattern = r"\s*p\.\s?p\.\s?s.*$" + elif self._postscript_marker == "P.S.": + postscript_pattern = r"\s*p\.\s?s\..*$" + else: + postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" + postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) + return True if postscript else False + + +class RephraseChecker(Instruction): + """Checks the rephrase.""" + + def build_description(self, *, original_message): + """Build the instruction description. + + Args: + original_message: A string representing the original message. The + rephrased response should only change its words/sentences in between + its two asterisks, for example, *change me*. Both original and rephrased + messages should contain the changes in the form of *change me*. + + Returns: + A string representing the instruction description. + """ + if not self.is_change(original_message): + raise ValueError( + f"Message {original_message} does not contain changes " + "in the form of *change me*." + ) + + self._reference_without_change = original_message + self._description = ( + "Rephrasing: Your rephrased response should only" + + "change the words/sentences in between two asterisks" + + "such as *change me*." + ) + return self._description + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"original_message": self._reference_without_change} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["original_message"] + + def check_following(self, value): + r"""Checks if the rephrasing follows the instruction. + + Args: + value: A string representing the response, which is expected to rephras + the string of `instruction_args`. + + Returns: + True if `value` and `instruction_args` only differ by the words/sentences + in between two asterisks such as *change me*; otherwise, False. + """ + + if not self.is_change(value): + raise ValueError( + f"value {value} does not contain changes in the form of *change me*." + ) + + response_without_changes = self.strip_changes(value) + reference_without_changes = self.strip_changes(self._reference_without_change) + + return response_without_changes == reference_without_changes + + def is_change(self, response): + """Check if there is change in the response in the form of *change me*.""" + return re.search(r"\*.*\*", response) + + def strip_changes(self, response): + """Strips off the changes.""" + return re.sub(r"\*.*\*", "", response) + + +class KeywordChecker(Instruction): + """Check the exisitence of certain keywords.""" + + def build_description(self, *, keywords=None): + """Build the instruction description. + + Args: + keywords: A sequence of strings representing the keywords that are + expected in the response. + + Returns: + A string representing the instruction description. + """ + + if not keywords: + self._keywords = instructions_util.generate_keywords( + num_keywords=_NUM_KEYWORDS + ) + else: + self._keywords = keywords + self._keywords = sorted(self._keywords) + + self._description_pattern = "Include keywords {keywords} in the response." + + return self._description_pattern.format(keywords=self._keywords) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"keywords": self._keywords} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keywords"] + + def check_following(self, value): + """Check if the response contain the expected keywords.""" + for keyword in self._keywords: + if not re.search(keyword, value, flags=re.IGNORECASE): + return False + return True + + +class KeywordFrequencyChecker(Instruction): + """Check the keyword frequency.""" + + def build_description(self, *, keyword=None, frequency=None, relation=None): + """Build the instruction description. + + Args: + keyword: A string representing a keyword that is expected in the response. + frequency: An integer specifying the number of times `keyword` is expected + to appear in the response. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of occurrences < frequency; + if 'at least', the actual number of occurrences >= frequency. + + Returns: + A string representing the instruction description. + """ + if not keyword: + self._keyword = instructions_util.generate_keywords(num_keywords=1)[0] + else: + self._keyword = keyword.strip() + + self._frequency = frequency + if self._frequency is None or self._frequency < 0: + self._frequency = random.randint(1, _KEYWORD_FREQUENCY) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = ( + "In your response, the word {keyword} should appear {relation} " + + "{frequency} times." + ) + + return self._description_pattern.format( + keyword=self._keyword, + relation=self._comparison_relation, + frequency=self._frequency, + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return { + "keyword": self._keyword, + "frequency": self._frequency, + "relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keyword", "frequency", "relation"] + + def check_following(self, value): + """Checks if the response contain the keyword with required frequency.""" + actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return actual_occurrences < self._frequency + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return actual_occurrences >= self._frequency + + +class NumberOfWords(Instruction): + """Checks the number of words.""" + + def build_description(self, *, num_words=None, relation=None): + """Build the instruction description. + + Args: + num_words: An integer specifying the number of words contained in the + response. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of words < num_words; + if 'at least', the actual number of words >= num_words. + + Returns: + A string representing the instruction description. + """ + + self._num_words = num_words + if self._num_words is None or self._num_words < 0: + self._num_words = random.randint( + _NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT + ) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = "Answer with {relation} {num_words} words." + + return self._description_pattern.format( + relation=self._comparison_relation, num_words=self._num_words + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"num_words": self._num_words, "relation": self._comparison_relation} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_words", "relation"] + + def check_following(self, value): + """Checks if the response contains the expected number of words.""" + num_words = instructions_util.count_words(value) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return num_words < self._num_words + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return num_words >= self._num_words + + +class JsonFormat(Instruction): + """Check the Json format.""" + + def build_description(self): + self._description_pattern = ( + "Entire output should be wrapped in JSON format. You can use markdown" + " ticks such as ```." + ) + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + value = ( + value.strip() + .removeprefix("```json") + .removeprefix("```Json") + .removeprefix("```JSON") + .removeprefix("```") + .removesuffix("```") + .strip() + ) + try: + json.loads(value) + except ValueError: + return False + return True + + +class ParagraphFirstWordCheck(Instruction): + """Check the paragraph and the first word of the nth paragraph.""" + + def build_description( + self, num_paragraphs=None, nth_paragraph=None, first_word=None + ): + r"""Build the instruction description. + + Args: + num_paragraphs: An integer indicating the number of paragraphs expected + in the response. A paragraph is a subset of the string that is + expected to be separated by '\n\n'. + nth_paragraph: An integer indicating the paragraph number that we look at. + Note that n starts from 1. + first_word: A string that represent the first word of the bth paragraph. + + Returns: + A string representing the instruction description. + """ + self._num_paragraphs = num_paragraphs + if self._num_paragraphs is None or self._num_paragraphs < 0: + self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) + + self._nth_paragraph = nth_paragraph + if ( + self._nth_paragraph is None + or self._nth_paragraph <= 0 + or self._nth_paragraph > self._num_paragraphs + ): + self._nth_paragraph = random.randint(1, self._num_paragraphs + 1) + + self._first_word = first_word + if self._first_word is None: + self._first_word = instructions_util.generate_keywords(num_keywords=1)[0] + self._first_word = self._first_word.lower() + + self._description_pattern = ( + "There should be {num_paragraphs} paragraphs. " + + "Paragraphs and only paragraphs are separated with each other by two " + + "new lines as if it was '\\n\\n' in python. " + + "Paragraph {nth_paragraph} must start with word {first_word}." + ) + + return self._description_pattern.format( + num_paragraphs=self._num_paragraphs, + nth_paragraph=self._nth_paragraph, + first_word=self._first_word, + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return { + "num_paragraphs": self._num_paragraphs, + "nth_paragraph": self._nth_paragraph, + "first_word": self._first_word, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_paragraphs", "nth_paragraph", "first_word"] + + def check_following(self, value): + """Checks for required number of paragraphs and correct first word. + + Args: + value: a string representing the response. The response may contain + paragraphs that are separated by two new lines and the first word of + the nth paragraph will have to match a specified word. + + Returns: + True if the number of paragraphs is the same as required and the first + word of the specified paragraph is the same as required. Otherwise, false. + """ + + paragraphs = re.split(r"\n\n", value) + num_paragraphs = len(paragraphs) + + for paragraph in paragraphs: + if not paragraph.strip(): + num_paragraphs -= 1 + + # check that index doesn't go out of bounds + if self._nth_paragraph <= num_paragraphs: + paragraph = paragraphs[self._nth_paragraph - 1].strip() + if not paragraph: + return False + else: + return False + + first_word = "" + punctuation = {".", ",", "?", "!", "'", '"'} + + # get first word and remove punctuation + word = paragraph.split()[0].strip() + # TODO(jeffrey): make more complex? + word = word.lstrip("'") + word = word.lstrip('"') + + for letter in word: + if letter in punctuation: + break + first_word += letter.lower() + + return num_paragraphs == self._num_paragraphs and first_word == self._first_word + + +# TODO(jeffrey) add relation - at least/at most? +class KeySentenceChecker(Instruction): + """Check the existence of certain key sentences.""" + + def build_description(self, key_sentences=None, num_sentences=None): + """Build the instruction description. + + Args: + key_sentences: A sequences of strings representing the key sentences that + are expected in the response. + num_sentences: The number of key sentences that are expected to be seen in + the response. + + Returns: + A string representing the instruction description. + """ + + if not key_sentences: + # TODO(jeffrey) make a generate sentences function? wonderwords package + self._key_sentences = set(["For now, this is fine."]) + else: + self._key_sentences = key_sentences + + if not num_sentences: + self._num_sentences = random.randint(1, len(self._key_sentences)) + else: + self._num_sentences = num_sentences + + self._description_pattern = ( + "Include {num_sentences} of the following sentences {key_sentences}" + ) + + return self._description_pattern.format( + num_sentences=self._num_sentences, key_sentences=self._key_sentences + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return { + "num_sentences": self._num_sentences, + "key_sentences": list(self._key_sentences), + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_sentences", "key_sentences"] + + def check_following(self, value): + """Checks if the response contains the expected key sentences.""" + count = 0 + sentences = instructions_util.split_into_sentences(value) + for sentence in self._key_sentences: + if sentence in sentences: + count += 1 + + return count == self._num_sentences + + +class ForbiddenWords(Instruction): + """Checks that specified words are not used in response.""" + + def build_description(self, forbidden_words=None): + """Build the instruction description. + + Args: + forbidden_words: A sequences of strings representing words that are not + allowed in the response. + + Returns: + A string representing the instruction description. + """ + + if not forbidden_words: + self._forbidden_words = instructions_util.generate_keywords( + num_keywords=_NUM_KEYWORDS + ) + else: + self._forbidden_words = list(set(forbidden_words)) + self._forbidden_words = sorted(self._forbidden_words) + self._description_pattern = ( + "Do not include keywords {forbidden_words} in the response." + ) + + return self._description_pattern.format(forbidden_words=self._forbidden_words) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return {"forbidden_words": self._forbidden_words} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["forbidden_words"] + + def check_following(self, value): + """Check if the response does not contain the expected keywords.""" + for word in self._forbidden_words: + if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): + return False + return True + + +class RephraseParagraph(Instruction): + """Checks that the paragraph is rephrased.""" + + def build_description(self, *, original_paragraph, low, high): + """Builds the instruction description. + + Args: + original_paragraph: A string presenting the original paragraph. The + rephrases response should have betweeb low-high words in common. + low: An integer presenting the lower bound of similar words. + high: An integer representing the upper bound of similar words. + + Returns: + A string representing the instruction description. + """ + # TODO(jeffrey) make more encompassing + self._original_paragraph = original_paragraph + self._low = low + self._high = high + + self._description = ( + "Rephrase the following paragraph: " + + "{original_paragraph}\nYour response should have " + + "between {low} and {high} of the same words. " + + "Words are the same if and only if all of the " + + "letters, ignoring cases, are the same. For " + + "example, 'run' is the same as 'Run' but different " + + "to 'ran'." + ) + + return self._description.format( + original_paragraph=original_paragraph, low=self._low, high=self._high + ) + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return { + "original_paragraph": self._original_paragraph, + "low": self._low, + "high": self._high, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["original_paragraph", "low", "high"] + + def check_following(self, value): + val_words = re.findall(r"\w+", value.lower()) + original_words = re.findall(r"\w+", self._original_paragraph.lower()) + similar_words = 0 + + dict_val = collections.Counter(val_words) + dict_original = collections.Counter(original_words) + + for word in dict_original: + similar_words += min(dict_original[word], dict_val[word]) + + return similar_words >= self._low and similar_words <= self._high + + +class TwoResponsesChecker(Instruction): + """Check that two responses were given.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Give two different responses. Responses and only responses should" + " be separated by 6 asterisk symbols: ******." + ) + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyword args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response has two different answers. + + Args: + value: A string representing the response. + + Returns: + True if two responses are detected and false otherwise. + """ + valid_responses = list() + responses = value.split("******") + for index, response in enumerate(responses): + if not response.strip(): + if index != 0 and index != len(responses) - 1: + return False + else: + valid_responses.append(response) + return ( + len(valid_responses) == 2 + and valid_responses[0].strip() != valid_responses[1].strip() + ) + + +class RepeatPromptThenAnswer(Instruction): + """Checks that Prompt is first repeated then answered.""" + + def build_description(self, *, prompt_to_repeat=None): + """Build the instruction description. + + Args: + prompt_to_repeat: The prompt that is meant to be repeated. + + Returns: + A string representing the instruction description. + """ + if not prompt_to_repeat: + raise ValueError("prompt_to_repeat must be set.") + else: + self._prompt_to_repeat = prompt_to_repeat + self._description_pattern = ( + "First repeat the request word for word without change," + " then give your answer (1. do not say any words or characters" + " before repeating the request; 2. the request you need to repeat" + " does not include this sentence)" + ) + return self._description_pattern + + def get_instruction_args(self): + return {"prompt_to_repeat": self._prompt_to_repeat} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["prompt_to_repeat"] + + def check_following(self, value): + if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()): + return True + return False + + +class EndChecker(Instruction): + """Checks that the prompt ends with a given phrase.""" + + def build_description(self, *, end_phrase=None): + """Build the instruction description. + + Args: + end_phrase: A string representing the phrase the response should end with. + + Returns: + A string representing the instruction description. + """ + self._end_phrase = ( + end_phrase.strip() if isinstance(end_phrase, str) else end_phrase + ) + if self._end_phrase is None: + self._end_phrase = random.choice(_ENDING_OPTIONS) + self._description_pattern = ( + "Finish your response with this exact phrase {ender}. " + "No other words should follow this phrase." + ) + return self._description_pattern.format(ender=self._end_phrase) + + def get_instruction_args(self): + return {"end_phrase": self._end_phrase} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["end_phrase"] + + def check_following(self, value): + """Checks if the response ends with the expected phrase.""" + value = value.strip().strip('"').lower() + self._end_phrase = self._end_phrase.strip().lower() + return value.endswith(self._end_phrase) + + +class TitleChecker(Instruction): + """Checks the response for a title.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Your answer must contain a title, wrapped in double angular brackets," + " such as <>." + ) + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response contains a title.""" + pattern = r"<<[^\n]+>>" + re_pattern = re.compile(pattern) + titles = re.findall(re_pattern, value) + + for title in titles: + if title.lstrip("<").rstrip(">").strip(): + return True + return False + + +class LetterFrequencyChecker(Instruction): + """Checks letter frequency.""" + + def build_description(self, *, letter=None, let_frequency=None, let_relation=None): + """Build the instruction description. + + Args: + letter: A string representing a letter that is expected in the response. + let_frequency: An integer specifying the number of times `keyword` is + expected to appear in the response. + let_relation: A string in (`less than`, `at least`), defining the + relational operator for comparison. Two relational comparisons are + supported for now; if 'less than', the actual number of + occurrences < frequency; if 'at least', the actual number of + occurrences >= frequency. + + Returns: + A string representing the instruction description. + """ + if ( + not letter + or len(letter) > 1 + or ord(letter.lower()) < 97 + or ord(letter.lower()) > 122 + ): + self._letter = random.choice(list(string.ascii_letters)) + else: + self._letter = letter.strip() + self._letter = self._letter.lower() + + self._frequency = let_frequency + if self._frequency is None or self._frequency < 0: + self._frequency = random.randint(1, _LETTER_FREQUENCY) + + if let_relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif let_relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {let_relation} is given." + ) + else: + self._comparison_relation = let_relation + + self._description_pattern = ( + "In your response, the letter {letter} should appear {let_relation}" + " {let_frequency} times." + ) + + return self._description_pattern.format( + letter=self._letter, + let_frequency=self._frequency, + let_relation=self._comparison_relation, + ) + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return { + "letter": self._letter, + "let_frequency": self._frequency, + "let_relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["letter", "let_frequency", "let_relation"] + + def check_following(self, value): + """Checks that the response contains the letter at the right frequency.""" + value = value.lower() + letters = collections.Counter(value) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return letters[self._letter] < self._frequency + else: + return letters[self._letter] >= self._frequency + + +class CapitalLettersEnglishChecker(Instruction): + """Checks that the response is in english and is in all capital letters.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Your entire response should be in English, and in all capital letters." + ) + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response is in English and in all capital letters.""" + assert isinstance(value, str) + + try: + return value.isupper() and langdetect.detect(value) == "en" + except langdetect.LangDetectException as e: + # Count as instruction is followed. + logging.error( + "Unable to detect language for text %s due to %s", value, e + ) # refex: disable=pytotw.037 + return True + + +class LowercaseLettersEnglishChecker(Instruction): + """Checks that the response is in english and is in all lowercase letters.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Your entire response should be in English, and in all lowercase" + " letters. No capital letters are allowed." + ) + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response is in English and in all lowercase letters.""" + assert isinstance(value, str) + + try: + return value.islower() and langdetect.detect(value) == "en" + except langdetect.LangDetectException as e: + # Count as instruction is followed. + logging.error( + "Unable to detect language for text %s due to %s", value, e + ) # refex: disable=pytotw.037 + return True + + +class CommaChecker(Instruction): + """Checks the response for no commas.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "In your entire response, refrain from the use of any commas." + ) + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response does not contain commas.""" + return not re.search(r"\,", value) + + +class CapitalWordFrequencyChecker(Instruction): + """Checks frequency of words with all capital letters.""" + + def build_description( + self, + capital_frequency=None, + capital_relation=None, + ): + """Build the instruction description. + + Args: + capital_frequency: An integer that represents the number of words that + should be in all capital letters. + capital_relation: A string that is 'at least' or 'at most' that refers to + the frequency. + + Returns: + A string representing the instruction description. + """ + self._frequency = capital_frequency + if self._frequency is None: + self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY) + + self._comparison_relation = capital_relation + if capital_relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif capital_relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {capital_relation} is given." + ) + + self._description_pattern = ( + "In your response, words with all capital letters should appear" + " {relation} {frequency} times." + ) + + return self._description_pattern.format( + frequency=self._frequency, relation=self._comparison_relation + ) + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return { + "capital_frequency": self._frequency, + "capital_relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["capital_frequency", "capital_relation"] + + def check_following(self, value): + """Checks the frequency of words with all capital letters.""" + # Hyphenated words will count as one word + words = instructions_util.nltk.word_tokenize(value) + capital_words = [word for word in words if word.isupper()] + + capital_words = len(capital_words) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return capital_words < self._frequency + else: + return capital_words >= self._frequency + + +class QuotationChecker(Instruction): + """Checks response is wrapped with double quotation marks.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Wrap your entire response with double quotation marks." + ) + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response is wrapped with double quotation marks.""" + value = value.strip() + return len(value) > 1 and value[0] == '"' and value[-1] == '"' diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions_registry.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..00d9a1de1985beacead34215952ecf4642d1ea35 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions_registry.py @@ -0,0 +1,168 @@ +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Registry of all instructions.""" + +from lm_eval.tasks.ifeval import instructions + + +_KEYWORD = "keywords:" + +_LANGUAGE = "language:" + +_LENGTH = "length_constraints:" + +_CONTENT = "detectable_content:" + +_FORMAT = "detectable_format:" + +_MULTITURN = "multi-turn:" + +_COMBINATION = "combination:" + +_STARTEND = "startend:" + +_CHANGE_CASES = "change_case:" + +_PUNCTUATION = "punctuation:" + +INSTRUCTION_DICT = { + _KEYWORD + "existence": instructions.KeywordChecker, + _KEYWORD + "frequency": instructions.KeywordFrequencyChecker, + # TODO(jeffreyzhou): make a proper set of sentences to choose from + # _KEYWORD + "key_sentences": instructions.KeySentenceChecker, + _KEYWORD + "forbidden_words": instructions.ForbiddenWords, + _KEYWORD + "letter_frequency": instructions.LetterFrequencyChecker, + _LANGUAGE + "response_language": instructions.ResponseLanguageChecker, + _LENGTH + "number_sentences": instructions.NumberOfSentences, + _LENGTH + "number_paragraphs": instructions.ParagraphChecker, + _LENGTH + "number_words": instructions.NumberOfWords, + _LENGTH + "nth_paragraph_first_word": instructions.ParagraphFirstWordCheck, + _CONTENT + "number_placeholders": instructions.PlaceholderChecker, + _CONTENT + "postscript": instructions.PostscriptChecker, + _FORMAT + "number_bullet_lists": instructions.BulletListChecker, + # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace + # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph, + _FORMAT + "constrained_response": instructions.ConstrainedResponseChecker, + _FORMAT + "number_highlighted_sections": (instructions.HighlightSectionChecker), + _FORMAT + "multiple_sections": instructions.SectionChecker, + # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message. + # _FORMAT + "rephrase": instructions.RephraseChecker, + _FORMAT + "json_format": instructions.JsonFormat, + _FORMAT + "title": instructions.TitleChecker, + # TODO(tianjianlu): Re-enable with specific prompts. + # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker, + _COMBINATION + "two_responses": instructions.TwoResponsesChecker, + _COMBINATION + "repeat_prompt": instructions.RepeatPromptThenAnswer, + _STARTEND + "end_checker": instructions.EndChecker, + _CHANGE_CASES + "capital_word_frequency": instructions.CapitalWordFrequencyChecker, + _CHANGE_CASES + "english_capital": instructions.CapitalLettersEnglishChecker, + _CHANGE_CASES + "english_lowercase": instructions.LowercaseLettersEnglishChecker, + _PUNCTUATION + "no_comma": instructions.CommaChecker, + _STARTEND + "quotation": instructions.QuotationChecker, +} + +INSTRUCTION_CONFLICTS = { + _KEYWORD + "existence": {_KEYWORD + "existence"}, + _KEYWORD + "frequency": {_KEYWORD + "frequency"}, + # TODO(jeffreyzhou): make a proper set of sentences to choose from + # _KEYWORD + "key_sentences": instructions.KeySentenceChecker, + _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"}, + _KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"}, + _LANGUAGE + "response_language": { + _LANGUAGE + "response_language", + _FORMAT + "multiple_sections", + _KEYWORD + "existence", + _KEYWORD + "frequency", + _KEYWORD + "forbidden_words", + _STARTEND + "end_checker", + _CHANGE_CASES + "english_capital", + _CHANGE_CASES + "english_lowercase", + }, + _LENGTH + "number_sentences": {_LENGTH + "number_sentences"}, + _LENGTH + "number_paragraphs": { + _LENGTH + "number_paragraphs", + _LENGTH + "nth_paragraph_first_word", + _LENGTH + "number_sentences", + _LENGTH + "nth_paragraph_first_word", + }, + _LENGTH + "number_words": {_LENGTH + "number_words"}, + _LENGTH + "nth_paragraph_first_word": { + _LENGTH + "nth_paragraph_first_word", + _LENGTH + "number_paragraphs", + }, + _CONTENT + "number_placeholders": {_CONTENT + "number_placeholders"}, + _CONTENT + "postscript": {_CONTENT + "postscript"}, + _FORMAT + "number_bullet_lists": {_FORMAT + "number_bullet_lists"}, + # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace + # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph, + _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()), + _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"}, + _FORMAT + "multiple_sections": { + _FORMAT + "multiple_sections", + _LANGUAGE + "response_language", + _FORMAT + "number_highlighted_sections", + }, + # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message. + # _FORMAT + "rephrase": instructions.RephraseChecker, + _FORMAT + "json_format": set(INSTRUCTION_DICT.keys()).difference( + {_KEYWORD + "forbidden_words", _KEYWORD + "existence"} + ), + _FORMAT + "title": {_FORMAT + "title"}, + # TODO(tianjianlu): Re-enable with specific prompts. + # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker, + _COMBINATION + "two_responses": set(INSTRUCTION_DICT.keys()).difference( + { + _KEYWORD + "forbidden_words", + _KEYWORD + "existence", + _LANGUAGE + "response_language", + _FORMAT + "title", + _PUNCTUATION + "no_comma", + } + ), + _COMBINATION + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference( + {_KEYWORD + "existence", _FORMAT + "title", _PUNCTUATION + "no_comma"} + ), + _STARTEND + "end_checker": {_STARTEND + "end_checker"}, + _CHANGE_CASES + "capital_word_frequency": { + _CHANGE_CASES + "capital_word_frequency", + _CHANGE_CASES + "english_lowercase", + _CHANGE_CASES + "english_capital", + }, + _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"}, + _CHANGE_CASES + "english_lowercase": { + _CHANGE_CASES + "english_lowercase", + _CHANGE_CASES + "english_capital", + }, + _PUNCTUATION + "no_comma": {_PUNCTUATION + "no_comma"}, + _STARTEND + "quotation": {_STARTEND + "quotation", _FORMAT + "title"}, +} + + +def conflict_make(conflicts): + """Makes sure if A conflicts with B, B will conflict with A. + + Args: + conflicts: Dictionary of potential conflicts where key is instruction id + and value is set of instruction ids that it conflicts with. + + Returns: + Revised version of the dictionary. All instructions conflict with + themselves. If A conflicts with B, B will conflict with A. + """ + for key in conflicts: + for k in conflicts[key]: + conflicts[k].add(key) + conflicts[key].add(key) + return conflicts diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions_util.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions_util.py new file mode 100644 index 0000000000000000000000000000000000000000..33e0a0a00c54f301334dc1bcd211dd588e6c9529 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/instructions_util.py @@ -0,0 +1,1701 @@ +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility library of instructions.""" + +import functools +import os +import random +import re +from importlib.metadata import version + +import immutabledict +import nltk +from packaging.version import parse as parse_version + + +# Downloading 'punkt' with nltk<3.9 has a remote code vuln. +# see https://github.com/EleutherAI/lm-evaluation-harness/issues/2210 +# and https://github.com/nltk/nltk/issues/3266 +# for more information. +NLTK_MIN_VERSION = "3.9.1" +RANK = os.environ.get("LOCAL_RANK", "0") + + +def download_nltk_resources(): + """Download 'punkt' if not already installed""" + assert (nltk_version := parse_version(version("nltk"))) >= parse_version( + NLTK_MIN_VERSION + ), ( + f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability." + ) + + try: + nltk.data.find("tokenizers/punkt_tab") + except LookupError: + if RANK == "0": + nltk.download("punkt_tab") + print("Downloaded punkt_tab on rank 0") + + +download_nltk_resources() + +WORD_LIST = [ + "western", + "sentence", + "signal", + "dump", + "spot", + "opposite", + "bottom", + "potato", + "administration", + "working", + "welcome", + "morning", + "good", + "agency", + "primary", + "wish", + "responsibility", + "press", + "problem", + "president", + "steal", + "brush", + "read", + "type", + "beat", + "trainer", + "growth", + "lock", + "bone", + "case", + "equal", + "comfortable", + "region", + "replacement", + "performance", + "mate", + "walk", + "medicine", + "film", + "thing", + "rock", + "tap", + "total", + "competition", + "ease", + "south", + "establishment", + "gather", + "parking", + "world", + "plenty", + "breath", + "claim", + "alcohol", + "trade", + "dear", + "highlight", + "street", + "matter", + "decision", + "mess", + "agreement", + "studio", + "coach", + "assist", + "brain", + "wing", + "style", + "private", + "top", + "brown", + "leg", + "buy", + "procedure", + "method", + "speed", + "high", + "company", + "valuable", + "pie", + "analyst", + "session", + "pattern", + "district", + "pleasure", + "dinner", + "swimming", + "joke", + "order", + "plate", + "department", + "motor", + "cell", + "spend", + "cabinet", + "difference", + "power", + "examination", + "engine", + "horse", + "dimension", + "pay", + "toe", + "curve", + "literature", + "bother", + "fire", + "possibility", + "debate", + "activity", + "passage", + "hello", + "cycle", + "background", + "quiet", + "author", + "effect", + "actor", + "page", + "bicycle", + "error", + "throat", + "attack", + "character", + "phone", + "tea", + "increase", + "outcome", + "file", + "specific", + "inspector", + "internal", + "potential", + "staff", + "building", + "employer", + "shoe", + "hand", + "direction", + "garden", + "purchase", + "interview", + "study", + "recognition", + "member", + "spiritual", + "oven", + "sandwich", + "weird", + "passenger", + "particular", + "response", + "reaction", + "size", + "variation", + "a", + "cancel", + "candy", + "exit", + "guest", + "condition", + "fly", + "price", + "weakness", + "convert", + "hotel", + "great", + "mouth", + "mind", + "song", + "sugar", + "suspect", + "telephone", + "ear", + "roof", + "paint", + "refrigerator", + "organization", + "jury", + "reward", + "engineering", + "day", + "possession", + "crew", + "bar", + "road", + "description", + "celebration", + "score", + "mark", + "letter", + "shower", + "suggestion", + "sir", + "luck", + "national", + "progress", + "hall", + "stroke", + "theory", + "offer", + "story", + "tax", + "definition", + "history", + "ride", + "medium", + "opening", + "glass", + "elevator", + "stomach", + "question", + "ability", + "leading", + "village", + "computer", + "city", + "grand", + "confidence", + "candle", + "priest", + "recommendation", + "point", + "necessary", + "body", + "desk", + "secret", + "horror", + "noise", + "culture", + "warning", + "water", + "round", + "diet", + "flower", + "bus", + "tough", + "permission", + "week", + "prompt", + "connection", + "abuse", + "height", + "save", + "corner", + "border", + "stress", + "drive", + "stop", + "rip", + "meal", + "listen", + "confusion", + "girlfriend", + "living", + "relation", + "significance", + "plan", + "creative", + "atmosphere", + "blame", + "invite", + "housing", + "paper", + "drink", + "roll", + "silver", + "drunk", + "age", + "damage", + "smoke", + "environment", + "pack", + "savings", + "influence", + "tourist", + "rain", + "post", + "sign", + "grandmother", + "run", + "profit", + "push", + "clerk", + "final", + "wine", + "swim", + "pause", + "stuff", + "singer", + "funeral", + "average", + "source", + "scene", + "tradition", + "personal", + "snow", + "nobody", + "distance", + "sort", + "sensitive", + "animal", + "major", + "negotiation", + "click", + "mood", + "period", + "arrival", + "expression", + "holiday", + "repeat", + "dust", + "closet", + "gold", + "bad", + "sail", + "combination", + "clothes", + "emphasis", + "duty", + "black", + "step", + "school", + "jump", + "document", + "professional", + "lip", + "chemical", + "front", + "wake", + "while", + "inside", + "watch", + "row", + "subject", + "penalty", + "balance", + "possible", + "adult", + "aside", + "sample", + "appeal", + "wedding", + "depth", + "king", + "award", + "wife", + "blow", + "site", + "camp", + "music", + "safe", + "gift", + "fault", + "guess", + "act", + "shame", + "drama", + "capital", + "exam", + "stupid", + "record", + "sound", + "swing", + "novel", + "minimum", + "ratio", + "machine", + "shape", + "lead", + "operation", + "salary", + "cloud", + "affair", + "hit", + "chapter", + "stage", + "quantity", + "access", + "army", + "chain", + "traffic", + "kick", + "analysis", + "airport", + "time", + "vacation", + "philosophy", + "ball", + "chest", + "thanks", + "place", + "mountain", + "advertising", + "red", + "past", + "rent", + "return", + "tour", + "house", + "construction", + "net", + "native", + "war", + "figure", + "fee", + "spray", + "user", + "dirt", + "shot", + "task", + "stick", + "friend", + "software", + "promotion", + "interaction", + "surround", + "block", + "purpose", + "practice", + "conflict", + "routine", + "requirement", + "bonus", + "hole", + "state", + "junior", + "sweet", + "catch", + "tear", + "fold", + "wall", + "editor", + "life", + "position", + "pound", + "respect", + "bathroom", + "coat", + "script", + "job", + "teach", + "birth", + "view", + "resolve", + "theme", + "employee", + "doubt", + "market", + "education", + "serve", + "recover", + "tone", + "harm", + "miss", + "union", + "understanding", + "cow", + "river", + "association", + "concept", + "training", + "recipe", + "relationship", + "reserve", + "depression", + "proof", + "hair", + "revenue", + "independent", + "lift", + "assignment", + "temporary", + "amount", + "loss", + "edge", + "track", + "check", + "rope", + "estimate", + "pollution", + "stable", + "message", + "delivery", + "perspective", + "mirror", + "assistant", + "representative", + "witness", + "nature", + "judge", + "fruit", + "tip", + "devil", + "town", + "emergency", + "upper", + "drop", + "stay", + "human", + "neck", + "speaker", + "network", + "sing", + "resist", + "league", + "trip", + "signature", + "lawyer", + "importance", + "gas", + "choice", + "engineer", + "success", + "part", + "external", + "worker", + "simple", + "quarter", + "student", + "heart", + "pass", + "spite", + "shift", + "rough", + "lady", + "grass", + "community", + "garage", + "youth", + "standard", + "skirt", + "promise", + "blind", + "television", + "disease", + "commission", + "positive", + "energy", + "calm", + "presence", + "tune", + "basis", + "preference", + "head", + "common", + "cut", + "somewhere", + "presentation", + "current", + "thought", + "revolution", + "effort", + "master", + "implement", + "republic", + "floor", + "principle", + "stranger", + "shoulder", + "grade", + "button", + "tennis", + "police", + "collection", + "account", + "register", + "glove", + "divide", + "professor", + "chair", + "priority", + "combine", + "peace", + "extension", + "maybe", + "evening", + "frame", + "sister", + "wave", + "code", + "application", + "mouse", + "match", + "counter", + "bottle", + "half", + "cheek", + "resolution", + "back", + "knowledge", + "make", + "discussion", + "screw", + "length", + "accident", + "battle", + "dress", + "knee", + "log", + "package", + "it", + "turn", + "hearing", + "newspaper", + "layer", + "wealth", + "profile", + "imagination", + "answer", + "weekend", + "teacher", + "appearance", + "meet", + "bike", + "rise", + "belt", + "crash", + "bowl", + "equivalent", + "support", + "image", + "poem", + "risk", + "excitement", + "remote", + "secretary", + "public", + "produce", + "plane", + "display", + "money", + "sand", + "situation", + "punch", + "customer", + "title", + "shake", + "mortgage", + "option", + "number", + "pop", + "window", + "extent", + "nothing", + "experience", + "opinion", + "departure", + "dance", + "indication", + "boy", + "material", + "band", + "leader", + "sun", + "beautiful", + "muscle", + "farmer", + "variety", + "fat", + "handle", + "director", + "opportunity", + "calendar", + "outside", + "pace", + "bath", + "fish", + "consequence", + "put", + "owner", + "go", + "doctor", + "information", + "share", + "hurt", + "protection", + "career", + "finance", + "force", + "golf", + "garbage", + "aspect", + "kid", + "food", + "boot", + "milk", + "respond", + "objective", + "reality", + "raw", + "ring", + "mall", + "one", + "impact", + "area", + "news", + "international", + "series", + "impress", + "mother", + "shelter", + "strike", + "loan", + "month", + "seat", + "anything", + "entertainment", + "familiar", + "clue", + "year", + "glad", + "supermarket", + "natural", + "god", + "cost", + "conversation", + "tie", + "ruin", + "comfort", + "earth", + "storm", + "percentage", + "assistance", + "budget", + "strength", + "beginning", + "sleep", + "other", + "young", + "unit", + "fill", + "store", + "desire", + "hide", + "value", + "cup", + "maintenance", + "nurse", + "function", + "tower", + "role", + "class", + "camera", + "database", + "panic", + "nation", + "basket", + "ice", + "art", + "spirit", + "chart", + "exchange", + "feedback", + "statement", + "reputation", + "search", + "hunt", + "exercise", + "nasty", + "notice", + "male", + "yard", + "annual", + "collar", + "date", + "platform", + "plant", + "fortune", + "passion", + "friendship", + "spread", + "cancer", + "ticket", + "attitude", + "island", + "active", + "object", + "service", + "buyer", + "bite", + "card", + "face", + "steak", + "proposal", + "patient", + "heat", + "rule", + "resident", + "broad", + "politics", + "west", + "knife", + "expert", + "girl", + "design", + "salt", + "baseball", + "grab", + "inspection", + "cousin", + "couple", + "magazine", + "cook", + "dependent", + "security", + "chicken", + "version", + "currency", + "ladder", + "scheme", + "kitchen", + "employment", + "local", + "attention", + "manager", + "fact", + "cover", + "sad", + "guard", + "relative", + "county", + "rate", + "lunch", + "program", + "initiative", + "gear", + "bridge", + "breast", + "talk", + "dish", + "guarantee", + "beer", + "vehicle", + "reception", + "woman", + "substance", + "copy", + "lecture", + "advantage", + "park", + "cold", + "death", + "mix", + "hold", + "scale", + "tomorrow", + "blood", + "request", + "green", + "cookie", + "church", + "strip", + "forever", + "beyond", + "debt", + "tackle", + "wash", + "following", + "feel", + "maximum", + "sector", + "sea", + "property", + "economics", + "menu", + "bench", + "try", + "language", + "start", + "call", + "solid", + "address", + "income", + "foot", + "senior", + "honey", + "few", + "mixture", + "cash", + "grocery", + "link", + "map", + "form", + "factor", + "pot", + "model", + "writer", + "farm", + "winter", + "skill", + "anywhere", + "birthday", + "policy", + "release", + "husband", + "lab", + "hurry", + "mail", + "equipment", + "sink", + "pair", + "driver", + "consideration", + "leather", + "skin", + "blue", + "boat", + "sale", + "brick", + "two", + "feed", + "square", + "dot", + "rush", + "dream", + "location", + "afternoon", + "manufacturer", + "control", + "occasion", + "trouble", + "introduction", + "advice", + "bet", + "eat", + "kill", + "category", + "manner", + "office", + "estate", + "pride", + "awareness", + "slip", + "crack", + "client", + "nail", + "shoot", + "membership", + "soft", + "anybody", + "web", + "official", + "individual", + "pizza", + "interest", + "bag", + "spell", + "profession", + "queen", + "deal", + "resource", + "ship", + "guy", + "chocolate", + "joint", + "formal", + "upstairs", + "car", + "resort", + "abroad", + "dealer", + "associate", + "finger", + "surgery", + "comment", + "team", + "detail", + "crazy", + "path", + "tale", + "initial", + "arm", + "radio", + "demand", + "single", + "draw", + "yellow", + "contest", + "piece", + "quote", + "pull", + "commercial", + "shirt", + "contribution", + "cream", + "channel", + "suit", + "discipline", + "instruction", + "concert", + "speech", + "low", + "effective", + "hang", + "scratch", + "industry", + "breakfast", + "lay", + "join", + "metal", + "bedroom", + "minute", + "product", + "rest", + "temperature", + "many", + "give", + "argument", + "print", + "purple", + "laugh", + "health", + "credit", + "investment", + "sell", + "setting", + "lesson", + "egg", + "middle", + "marriage", + "level", + "evidence", + "phrase", + "love", + "self", + "benefit", + "guidance", + "affect", + "you", + "dad", + "anxiety", + "special", + "boyfriend", + "test", + "blank", + "payment", + "soup", + "obligation", + "reply", + "smile", + "deep", + "complaint", + "addition", + "review", + "box", + "towel", + "minor", + "fun", + "soil", + "issue", + "cigarette", + "internet", + "gain", + "tell", + "entry", + "spare", + "incident", + "family", + "refuse", + "branch", + "can", + "pen", + "grandfather", + "constant", + "tank", + "uncle", + "climate", + "ground", + "volume", + "communication", + "kind", + "poet", + "child", + "screen", + "mine", + "quit", + "gene", + "lack", + "charity", + "memory", + "tooth", + "fear", + "mention", + "marketing", + "reveal", + "reason", + "court", + "season", + "freedom", + "land", + "sport", + "audience", + "classroom", + "law", + "hook", + "win", + "carry", + "eye", + "smell", + "distribution", + "research", + "country", + "dare", + "hope", + "whereas", + "stretch", + "library", + "if", + "delay", + "college", + "plastic", + "book", + "present", + "use", + "worry", + "champion", + "goal", + "economy", + "march", + "election", + "reflection", + "midnight", + "slide", + "inflation", + "action", + "challenge", + "guitar", + "coast", + "apple", + "campaign", + "field", + "jacket", + "sense", + "way", + "visual", + "remove", + "weather", + "trash", + "cable", + "regret", + "buddy", + "beach", + "historian", + "courage", + "sympathy", + "truck", + "tension", + "permit", + "nose", + "bed", + "son", + "person", + "base", + "meat", + "usual", + "air", + "meeting", + "worth", + "game", + "independence", + "physical", + "brief", + "play", + "raise", + "board", + "she", + "key", + "writing", + "pick", + "command", + "party", + "yesterday", + "spring", + "candidate", + "physics", + "university", + "concern", + "development", + "change", + "string", + "target", + "instance", + "room", + "bitter", + "bird", + "football", + "normal", + "split", + "impression", + "wood", + "long", + "meaning", + "stock", + "cap", + "leadership", + "media", + "ambition", + "fishing", + "essay", + "salad", + "repair", + "today", + "designer", + "night", + "bank", + "drawing", + "inevitable", + "phase", + "vast", + "chip", + "anger", + "switch", + "cry", + "twist", + "personality", + "attempt", + "storage", + "being", + "preparation", + "bat", + "selection", + "white", + "technology", + "contract", + "side", + "section", + "station", + "till", + "structure", + "tongue", + "taste", + "truth", + "difficulty", + "group", + "limit", + "main", + "move", + "feeling", + "light", + "example", + "mission", + "might", + "wait", + "wheel", + "shop", + "host", + "classic", + "alternative", + "cause", + "agent", + "consist", + "table", + "airline", + "text", + "pool", + "craft", + "range", + "fuel", + "tool", + "partner", + "load", + "entrance", + "deposit", + "hate", + "article", + "video", + "summer", + "feature", + "extreme", + "mobile", + "hospital", + "flight", + "fall", + "pension", + "piano", + "fail", + "result", + "rub", + "gap", + "system", + "report", + "suck", + "ordinary", + "wind", + "nerve", + "ask", + "shine", + "note", + "line", + "mom", + "perception", + "brother", + "reference", + "bend", + "charge", + "treat", + "trick", + "term", + "homework", + "bake", + "bid", + "status", + "project", + "strategy", + "orange", + "let", + "enthusiasm", + "parent", + "concentrate", + "device", + "travel", + "poetry", + "business", + "society", + "kiss", + "end", + "vegetable", + "employ", + "schedule", + "hour", + "brave", + "focus", + "process", + "movie", + "illegal", + "general", + "coffee", + "ad", + "highway", + "chemistry", + "psychology", + "hire", + "bell", + "conference", + "relief", + "show", + "neat", + "funny", + "weight", + "quality", + "club", + "daughter", + "zone", + "touch", + "tonight", + "shock", + "burn", + "excuse", + "name", + "survey", + "landscape", + "advance", + "satisfaction", + "bread", + "disaster", + "item", + "hat", + "prior", + "shopping", + "visit", + "east", + "photo", + "home", + "idea", + "father", + "comparison", + "cat", + "pipe", + "winner", + "count", + "lake", + "fight", + "prize", + "foundation", + "dog", + "keep", + "ideal", + "fan", + "struggle", + "peak", + "safety", + "solution", + "hell", + "conclusion", + "population", + "strain", + "alarm", + "measurement", + "second", + "train", + "race", + "due", + "insurance", + "boss", + "tree", + "monitor", + "sick", + "course", + "drag", + "appointment", + "slice", + "still", + "care", + "patience", + "rich", + "escape", + "emotion", + "royal", + "female", + "childhood", + "government", + "picture", + "will", + "sock", + "big", + "gate", + "oil", + "cross", + "pin", + "improvement", + "championship", + "silly", + "help", + "sky", + "pitch", + "man", + "diamond", + "most", + "transition", + "work", + "science", + "committee", + "moment", + "fix", + "teaching", + "dig", + "specialist", + "complex", + "guide", + "people", + "dead", + "voice", + "original", + "break", + "topic", + "data", + "degree", + "reading", + "recording", + "bunch", + "reach", + "judgment", + "lie", + "regular", + "set", + "painting", + "mode", + "list", + "player", + "bear", + "north", + "wonder", + "carpet", + "heavy", + "officer", + "negative", + "clock", + "unique", + "baby", + "pain", + "assumption", + "disk", + "iron", + "bill", + "drawer", + "look", + "double", + "mistake", + "finish", + "future", + "brilliant", + "contact", + "math", + "rice", + "leave", + "restaurant", + "discount", + "sex", + "virus", + "bit", + "trust", + "event", + "wear", + "juice", + "failure", + "bug", + "context", + "mud", + "whole", + "wrap", + "intention", + "draft", + "pressure", + "cake", + "dark", + "explanation", + "space", + "angle", + "word", + "efficiency", + "management", + "habit", + "star", + "chance", + "finding", + "transportation", + "stand", + "criticism", + "flow", + "door", + "injury", + "insect", + "surprise", + "apartment", +] # pylint: disable=line-too-long + +# ISO 639-1 codes to language names. +LANGUAGE_CODES = immutabledict.immutabledict( + { + "en": "English", + "es": "Spanish", + "pt": "Portuguese", + "ar": "Arabic", + "hi": "Hindi", + "fr": "French", + "ru": "Russian", + "de": "German", + "ja": "Japanese", + "it": "Italian", + "bn": "Bengali", + "uk": "Ukrainian", + "th": "Thai", + "ur": "Urdu", + "ta": "Tamil", + "te": "Telugu", + "bg": "Bulgarian", + "ko": "Korean", + "pl": "Polish", + "he": "Hebrew", + "fa": "Persian", + "vi": "Vietnamese", + "ne": "Nepali", + "sw": "Swahili", + "kn": "Kannada", + "mr": "Marathi", + "gu": "Gujarati", + "pa": "Punjabi", + "ml": "Malayalam", + "fi": "Finnish", + } +) + +_ALPHABETS = "([A-Za-z])" +_PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" +_SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" +_STARTERS = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" +_ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" +_WEBSITES = "[.](com|net|org|io|gov|edu|me)" +_DIGITS = "([0-9])" +_MULTIPLE_DOTS = r"\.{2,}" + + +def split_into_sentences(text): + """Split the text into sentences. + + Args: + text: A string that consists of more than or equal to one sentences. + + Returns: + A list of strings where each string is a sentence. + """ + text = " " + text + " " + text = text.replace("\n", " ") + text = re.sub(_PREFIXES, "\\1", text) + text = re.sub(_WEBSITES, "\\1", text) + text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1\\2", text) + text = re.sub( + _MULTIPLE_DOTS, + lambda match: "" * len(match.group(0)) + "", + text, + ) + if "Ph.D" in text: + text = text.replace("Ph.D.", "PhD") + text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1 ", text) + text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1 \\2", text) + text = re.sub( + _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]", + "\\1\\2\\3", + text, + ) + text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1\\2", text) + text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1 \\2", text) + text = re.sub(" " + _SUFFIXES + "[.]", " \\1", text) + text = re.sub(" " + _ALPHABETS + "[.]", " \\1", text) + if "”" in text: + text = text.replace(".”", "”.") + if '"' in text: + text = text.replace('."', '".') + if "!" in text: + text = text.replace('!"', '"!') + if "?" in text: + text = text.replace('?"', '"?') + text = text.replace(".", ".") + text = text.replace("?", "?") + text = text.replace("!", "!") + text = text.replace("", ".") + sentences = text.split("") + sentences = [s.strip() for s in sentences] + if sentences and not sentences[-1]: + sentences = sentences[:-1] + return sentences + + +def count_words(text): + """Counts the number of words.""" + tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+") + tokens = tokenizer.tokenize(text) + num_words = len(tokens) + return num_words + + +@functools.lru_cache(maxsize=None) +def _get_sentence_tokenizer(): + return nltk.data.load("nltk:tokenizers/punkt/english.pickle") + + +def count_sentences(text): + """Count the number of sentences.""" + tokenizer = _get_sentence_tokenizer() + tokenized_sentences = tokenizer.tokenize(text) + return len(tokenized_sentences) + + +def generate_keywords(num_keywords): + """Randomly generates a few keywords.""" + return random.sample(WORD_LIST, k=num_keywords) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..985e8d5ae578c484267c7c2d90ee7c896028941f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/ifeval/utils.py @@ -0,0 +1,134 @@ +import dataclasses +from typing import Dict, Optional, Union + +from lm_eval.tasks.ifeval import instructions_registry + + +@dataclasses.dataclass +class InputExample: + key: int + instruction_id_list: list[str] + prompt: str + kwargs: list[Dict[str, Optional[Union[str, int]]]] + + +@dataclasses.dataclass +class OutputExample: + instruction_id_list: list[str] + prompt: str + response: str + follow_all_instructions: bool + follow_instruction_list: list[bool] + + +def test_instruction_following_strict( + inp, + response, +): + """Tests response to see if instructions are followed.""" + instruction_list = inp.instruction_id_list + is_following_list = [] + + for index, instruction_id in enumerate(instruction_list): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) + + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. + kwargs = {k: v for k, v in inp.kwargs[index].items() if v} + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=inp.prompt) + + if response.strip() and instruction.check_following(response): + is_following_list.append(True) + else: + is_following_list.append(False) + + return OutputExample( + instruction_id_list=inp.instruction_id_list, + prompt=inp.prompt, + response=response, + follow_all_instructions=all(is_following_list), + follow_instruction_list=is_following_list, + ) + + +def test_instruction_following_loose( + inp, + response, +): + """Tests response for an upper bound for following instructions.""" + r = response.split("\n") + response_remove_first = "\n".join(r[1:]).strip() + response_remove_last = "\n".join(r[:-1]).strip() + response_remove_both = "\n".join(r[1:-1]).strip() + revised_response = response.replace("*", "") + revised_response_remove_first = response_remove_first.replace("*", "") + revised_response_remove_last = response_remove_last.replace("*", "") + revised_response_remove_both = response_remove_both.replace("*", "") + all_responses = [ + response, + revised_response, + response_remove_first, + response_remove_last, + response_remove_both, + revised_response_remove_first, + revised_response_remove_last, + revised_response_remove_both, + ] + instruction_list = inp.instruction_id_list + is_following_list = [] + + for index, instruction_id in enumerate(instruction_list): + instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id] + instruction = instruction_cls(instruction_id) + + # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method. + kwargs = {k: v for k, v in inp.kwargs[index].items() if v} + instruction.build_description(**kwargs) + args = instruction.get_instruction_args() + if args and "prompt" in args: + instruction.build_description(prompt=inp.prompt) + + is_following = False + for r in all_responses: + if r.strip() and instruction.check_following(r): + is_following = True + break + + is_following_list.append(is_following) + + return OutputExample( + instruction_id_list=inp.instruction_id_list, + prompt=inp.prompt, + response=response, + follow_all_instructions=all(is_following_list), + follow_instruction_list=is_following_list, + ) + + +def process_results(doc, results): + inp = InputExample( + key=doc["key"], + instruction_id_list=doc["instruction_id_list"], + prompt=doc["prompt"], + kwargs=doc["kwargs"], + ) + response = results[0] + + out_strict = test_instruction_following_strict(inp, response) + out_loose = test_instruction_following_loose(inp, response) + + return { + "prompt_level_strict_acc": out_strict.follow_all_instructions, + "inst_level_strict_acc": out_strict.follow_instruction_list, + "prompt_level_loose_acc": out_loose.follow_all_instructions, + "inst_level_loose_acc": out_loose.follow_instruction_list, + } + + +def agg_inst_level_acc(items): + flat_items = [item for sublist in items for item in sublist] + inst_level_acc = sum(flat_items) / len(flat_items) + return inst_level_acc diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/math500/math500.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/math500/math500.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fe2f7a38417fe863c1301953be514b618054707 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/math500/math500.yaml @@ -0,0 +1,12 @@ +task: math500 +dataset_path: HuggingFaceH4/MATH-500 +output_type: generate_until +test_split: test +doc_to_text: !function utils.math500_prompt +doc_to_target: "{{answer}}" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/math500/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/math500/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0585298c29c8b5c12ebeaa01dfff572267db601 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/math500/utils.py @@ -0,0 +1,14 @@ +def math500_prompt(doc): + system_prompt = ( + "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n" + "Respond in the following format:\n" + "\n" + "Your reasoning here\n" + "\n" + "\n" + "\\boxed{...}\n" + "" + ) + + prompt = f"{system_prompt}\n\n{doc['problem']}\n\n" + return prompt diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fd6df44fb76a1a9e017d60afc470200967696f19 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/README.md @@ -0,0 +1,43 @@ +# MBPP + +## Paper +Program Synthesis with Large Language Models +https://arxiv.org/abs/2108.07732 + +This paper explores the limits of the current generation of large language models for program synthesis in general purpose programming languages. We evaluate a collection of such models (with between 244M and 137B parameters) on two new benchmarks, MBPP and MathQA-Python, in both the few-shot and fine-tuning regimes. Our benchmarks are designed to measure the ability of these models to synthesize short Python programs from natural language descriptions. The Mostly Basic Programming Problems (MBPP) dataset contains 974 programming tasks, designed to be solvable by entry-level programmers. The MathQA-Python dataset, a Python version of the MathQA benchmark, contains 23914 problems that evaluate the ability of the models to synthesize code from more complex text. On both datasets, we find that synthesis performance scales log-linearly with model size. Our largest models, even without finetuning on a code dataset, can synthesize solutions to 59.6 percent of the problems from MBPP using few-shot learning with a well-designed prompt. Fine-tuning on a held-out portion of the dataset improves performance by about 10 percentage points across most model sizes. On the MathQA-Python dataset, the largest fine-tuned model achieves 83.8 percent accuracy. Going further, we study the model's ability to engage in dialog about code, incorporating human feedback to improve its solutions. We find that natural language feedback from a human halves the error rate compared to the model's initial prediction. Additionally, we conduct an error analysis to shed light on where these models fall short and what types of programs are most difficult to generate. Finally, we explore the semantic grounding of these models by fine-tuning them to predict the results of program execution. We find that even our best models are generally unable to predict the output of a program given a specific input. + +Homepage: https://github.com/google-research/google-research/tree/master/mbpp + + +## Citation +``` +@article{austin2021program, + title={Program synthesis with large language models}, + author={Austin, Jacob and Odena, Augustus and Nye, Maxwell and Bosma, Maarten and Michalewski, Henryk and Dohan, David and Jiang, Ellen and Cai, Carrie and Terry, Michael and Le, Quoc and others}, + journal={arXiv preprint arXiv:2108.07732}, + year={2021} +} +``` + +### Groups and Tasks + +#### Groups + +* Not part of a group yet. + +#### Tasks + +- `mbpp` + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9819870348baef1f450df6199348b626d59e4793 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp.yaml @@ -0,0 +1,23 @@ +task: mbpp +dataset_path: google-research-datasets/mbpp +dataset_name: full +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "You are an expert Python programmer, and here is your task: {{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}\n[BEGIN]\n" +doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}" +target_delimiter: "" +metric_list: + - metric: !function utils.pass_at_1 + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "[DONE]" + do_sample: false +num_fewshot: 0 +fewshot_config: + sampler: first_n + samples: !function utils.list_fewshot_samples +metadata: + version: 1.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_instruct.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fac69ebc030494d7529ddd0677576eae2de17b2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_instruct.yaml @@ -0,0 +1,24 @@ +task: mbpp_instruct +dataset_path: google-research-datasets/mbpp +dataset_name: full +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "You are an expert Python programmer, and here is your task: {{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}" +doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}" +gen_prefix: "Here is the completed function:\n```python\n" +target_delimiter: "" +metric_list: + - metric: !function utils.pass_at_1 + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "```" + do_sample: false +num_fewshot: 0 +fewshot_config: + sampler: first_n + samples: !function utils.list_fewshot_samples +metadata: + version: 1.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_plus.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_plus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..133c393c55e175ca958e5861cac91c8ec4f6beed --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_plus.yaml @@ -0,0 +1,5 @@ +include: mbpp.yaml +task: mbpp_plus +dataset_path: evalplus/mbppplus +dataset_name: null +doc_to_text: "You are an expert Python programmer, and here is your task: {{prompt if prompt is defined else text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}\n[BEGIN]\n" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_plus_instruct.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_plus_instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1124cc0629b4e12ecd1764bbd19d282320645f0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/mbpp_plus_instruct.yaml @@ -0,0 +1,5 @@ +include: mbpp_instruct.yaml +task: mbpp_plus_instruct +dataset_path: evalplus/mbppplus +dataset_name: null +doc_to_text: "You are an expert Python programmer, and here is your task: {{prompt if prompt is defined else text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ee7cddc127575bae2ae77530936e8c792b0065 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mbpp/utils.py @@ -0,0 +1,59 @@ +import evaluate as hf_evaluate + + +try: + pass_at_k = hf_evaluate.load("code_eval") + + # run simple test to check code execution is enabled before model generation + test_cases = ["assert add(2, 3)==5"] + candidates = [["def add(a,b): return a*b"]] + results = pass_at_k.compute(references=test_cases, predictions=candidates, k=[1]) +except Exception as e: + raise e + + +def pass_at_1(references, predictions): + print(predictions) + return pass_at_k.compute( + references=references, + predictions=[predictions], + k=[1], + )[0]["pass@1"] + + +def list_fewshot_samples(): + return [ + { + "task_id": 2, + "text": "Write a function to find the similar elements from the given two tuple lists.", + "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ", + "test_list": [ + "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)", + "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)", + "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)", + ], + "is_fewshot": True, + }, + { + "task_id": 3, + "text": "Write a python function to identify non-prime numbers.", + "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result", + "test_list": [ + "assert is_not_prime(2) == False", + "assert is_not_prime(10) == True", + "assert is_not_prime(35) == True", + ], + "is_fewshot": True, + }, + { + "task_id": 4, + "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.", + "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums", + "test_list": [ + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]", + ], + "is_fewshot": True, + }, + ] diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4cd78f76eb927db8f059fbba1a2e2bbe5a7ce03f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/README.md @@ -0,0 +1,68 @@ +# MATH +ℹ️ This is the 4-shot variant! +## Paper +Measuring Mathematical Problem Solving With the MATH Dataset +https://arxiv.org/abs/2103.03874 + +Many intellectual endeavors require mathematical problem solving, but this skill remains beyond the capabilities of computers. To measure this ability in machine learning models, we introduce MATH, a new dataset of 12,500 challenging competition mathematics problems. Each problem in MATH has a full step-by-step solution which can be used to teach models to generate answer derivations and explanations. + +NOTE: The few-shot and the generated answer extraction is based on the [Minerva](https://arxiv.org/abs/2206.14858) and exact match equivalence is calculated using the `sympy` library. This requires additional dependencies, which can be installed via the `lm-eval[math]` extra. + +Homepage: https://github.com/hendrycks/math + + +## Citation +``` +@article{hendrycksmath2021, + title={Measuring Mathematical Problem Solving With the MATH Dataset}, + author={Dan Hendrycks and Collin Burns and Saurav Kadavath and Akul Arora and Steven Basart and Eric Tang and Dawn Song and Jacob Steinhardt}, + journal={NeurIPS}, + year={2021} +} + +@misc{2206.14858, +Author = {Aitor Lewkowycz and Anders Andreassen and David Dohan and Ethan Dyer and Henryk Michalewski and Vinay Ramasesh and Ambrose Slone and Cem Anil and Imanol Schlag and Theo Gutman-Solo and Yuhuai Wu and Behnam Neyshabur and Guy Gur-Ari and Vedant Misra}, +Title = {Solving Quantitative Reasoning Problems with Language Models}, +Year = {2022}, +Eprint = {arXiv:2206.14858}, +} +``` + +### Groups and Tasks + +#### Groups + +- `minerva_math` + +#### Tasks + +- `minerva_math_algebra` +- `minerva_math_counting_and_prob` +- `minerva_math_geometry` +- `minerva_math_intermediate_algebra` +- `minerva_math_num_theory` +- `minerva_math_prealgebra` +- `minerva_math_precalc` + +### Checklist + +The checklist is the following: + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + * The implementation in the original paper is one where the model is first fine-tuned on the data. They do have a few-shot evaluation for GPT-3, however the few-shot context used here is sourced from [Lewkowycz et al](https://arxiv.org/abs/2206.14858). The achieved accuracy on Llama-2 models is comparable to that provided in the paper, though not identical. + + +If other tasks on this dataset are already supported: +* [x] Is the "Main" variant of this task clearly denoted? +* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [x] Have you noted which, if any, published evaluation setups are matched by this variant? + +### Variant Wishlist + +- [ ] zero-shot variant + +### Changelog +version 2.0: (21-Feb-2025); added math_verify (extraction) metric. For details [see](https://huggingface.co/blog/math_verify_leaderboard) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..daf84f0accbd322f24019aa632cae74988f2cb11 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_algebra.yaml @@ -0,0 +1,32 @@ +tag: + - math_word_problems +task: minerva_math_algebra +dataset_path: EleutherAI/hendrycks_math +process_docs: !function utils.process_docs +dataset_name: algebra +output_type: generate_until +training_split: train +test_split: test +doc_to_text: !function utils.doc_to_text +process_results: !function utils.process_results +doc_to_target: "{{answer if few_shot is undefined else solution}}" +generation_kwargs: + until: + - "Problem:" + do_sample: false + temperature: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + - metric: math_verify + aggregation: mean + higher_is_better: true +num_fewshot: 0 +metadata: + version: 2.0 +dataset_kwargs: + trust_remote_code: true +fewshot_config: + sampler: first_n + samples: !function utils.list_fewshot_samples diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_counting_and_prob.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_counting_and_prob.yaml new file mode 100644 index 0000000000000000000000000000000000000000..688cd711c50d005d5d78ca55116ad333d96161ce --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_counting_and_prob.yaml @@ -0,0 +1,3 @@ +include: minerva_math_algebra.yaml +dataset_name: counting_and_probability +task: minerva_math_counting_and_prob diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_geometry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_geometry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..079ee70e9ed8997f351d1732c0c88dad1e4896de --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_geometry.yaml @@ -0,0 +1,3 @@ +include: minerva_math_algebra.yaml +dataset_name: geometry +task: minerva_math_geometry diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_intermediate_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_intermediate_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b3f063c36e10063dd06be93c290820a787ddd1d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_intermediate_algebra.yaml @@ -0,0 +1,3 @@ +include: minerva_math_algebra.yaml +dataset_name: intermediate_algebra +task: minerva_math_intermediate_algebra diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_num_theory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_num_theory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44f587bce4cce5e4ab80d24b938b88488553d6da --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_num_theory.yaml @@ -0,0 +1,3 @@ +include: minerva_math_algebra.yaml +dataset_name: number_theory +task: minerva_math_num_theory diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_prealgebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_prealgebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..865e2f2c6e5397a07fb473a89f4d8eaf47d3eb52 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_prealgebra.yaml @@ -0,0 +1,3 @@ +include: minerva_math_algebra.yaml +dataset_name: prealgebra +task: minerva_math_prealgebra diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_precalc.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_precalc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..06e63abc7c206b43759217b38cd5db2395e554a9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/minerva_math_precalc.yaml @@ -0,0 +1,3 @@ +include: minerva_math_algebra.yaml +dataset_name: precalculus +task: minerva_math_precalc diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..984ba33f229d624c9fc6036fa8f05e4da9d5cca4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/minerva_math/utils.py @@ -0,0 +1,306 @@ +import logging +import re +import signal +from importlib.metadata import version +from typing import Dict, List, Optional + +import datasets + + +eval_logger = logging.getLogger(__name__) + + +try: + import antlr4 + import sympy + from math_verify import parse, verify + from sympy.parsing.latex import parse_latex + + assert version("antlr4-python3-runtime").startswith("4.11") +except (ModuleNotFoundError, AssertionError) as e: + raise type(e)( + "`sympy`, `math_verify` and `antlr4-python3-runtime==4.11` are required for generating translation task prompt templates. " + "Please install the required packages via pip install lm-eval[math] or pip install -e .[math]" + ) from e + + +# taken from +# https://github.com/wellecks/lm-evaluation-harness/blob/master/lm_eval/tasks/minerva_math.py +def doc_to_text(doc: dict) -> str: + return "Problem:" + "\n" + doc["problem"] + "\n\n" + "Solution:" + + +def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: + def _process_doc(doc: dict) -> dict: + out_doc = { + "problem": doc["problem"], + "solution": doc["solution"], + "answer": normalize_final_answer( + remove_boxed(last_boxed_only_string(doc["solution"])) + ), + } + if getattr(doc, "few_shot", None) is not None: + out_doc["few_shot"] = True + return out_doc + + return dataset.map(_process_doc) + + +def list_fewshot_samples() -> list[dict]: + return [ + { + "problem": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", + "solution": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.", + "few_shot": "1", + }, + { + "problem": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", + "solution": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.", + "few_shot": "1", + }, + { + "problem": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", + "solution": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$:\n\\begin{align*}\n30n&=480\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.", + "few_shot": "1", + }, + { + "problem": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero,\nfind $\\frac{a}{b},$ assuming $b$ is nonzero.", + "solution": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.", + "few_shot": "1", + }, + ] + + +def process_results(doc: dict, results: List[str]) -> Dict[str, int]: + candidates = results[0] + + unnormalized_answer = get_unnormalized_answer(candidates) + answer = normalize_final_answer(unnormalized_answer) + + if is_equiv(answer, doc["answer"]): + retval = 1 + else: + retval = 0 + + # math_verify + res = verify(parse(doc["answer"]), parse(candidates)) + mathval = 1 if res else 0 + + results = { + "exact_match": retval, + "math_verify": mathval, + } + return results + + +def last_boxed_only_string(string: str) -> Optional[str]: + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def remove_boxed(s: str) -> str: + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +def is_equiv(x1: str, x2: str) -> bool: + """ + x1 and x2 are normalized latex string + """ + try: + with timeout(seconds=5): + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex.errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + eval_logger.debug(f"couldn't parse one of {x1} or {x2}") + return False + + try: + diff = parsed_x1 - parsed_x2 + except TypeError: + eval_logger.debug(f"couldn't subtract {x1} and {x2}") + return False + + try: + if sympy.simplify(diff) == 0: + return True + else: + return False + except ValueError: + eval_logger.debug( + f"Had some trouble simplifying when comparing {x1} and {x2}" + ) + except TimeoutError: + eval_logger.debug(f"Timed out comparing {x1} and {x2}") + return False + except ImportError as e: + eval_logger.error(e) + raise + except Exception as e: + eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") + return False + + +def get_unnormalized_answer(text: str) -> str: + INVALID_ANSWER = "[invalidanswer]" + end_seq = "I hope it is correct." + text += end_seq + match = re.search( + r"Final Answer: The final answer is(.*?). I hope it is correct.", + text, + ) + if match: + return match.group(1).strip() + else: + return INVALID_ANSWER + + +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "ft", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """ + Normalize a final answer to a quantitative reasoning question. + + Copied character for character from appendix D of Lewkowycz et al. (2022) + """ + final_answer = final_answer.split("=")[-1] + + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a3425d517654a6b93e03ee1bb681e07de18c4016 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/README.md @@ -0,0 +1,73 @@ +# Task-name + +### Paper + +Title: `Measuring Massive Multitask Language Understanding` + +Abstract: `https://arxiv.org/abs/2009.03300` + +`The test covers 57 tasks including elementary mathematics, US history, computer science, law, and more.` + +Homepage: `https://github.com/hendrycks/test` + +Note: The `Flan` variants are derived from [here](https://github.com/jasonwei20/flan-2), and as described in Appendix D.1 of [Scaling Instruction-Finetuned Language Models](https://arxiv.org/abs/2210.11416). + +### Citation + +``` +@article{hendryckstest2021, + title={Measuring Massive Multitask Language Understanding}, + author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, + journal={Proceedings of the International Conference on Learning Representations (ICLR)}, + year={2021} +} + +@article{hendrycks2021ethics, + title={Aligning AI With Shared Human Values}, + author={Dan Hendrycks and Collin Burns and Steven Basart and Andrew Critch and Jerry Li and Dawn Song and Jacob Steinhardt}, + journal={Proceedings of the International Conference on Learning Representations (ICLR)}, + year={2021} +} +``` + +### Groups, Tags, and Tasks + +#### Groups + +* `mmlu`: `Original multiple-choice MMLU benchmark` +* `mmlu_continuation`: `MMLU but with continuation prompts` +* `mmlu_generation`: `MMLU generation` + +MMLU is the original benchmark as implemented by Hendrycks et al. with the choices in context and the answer letters (e.g `A`, `B`, `C`, `D`) in the continuation. +`mmlu_continuation` is a cloze-style variant without the choices in context and the full answer choice in the continuation. +`mmlu_generation` is a generation variant, similar to the original but the LLM is asked to generate the correct answer letter. + + +#### Subgroups + +* `mmlu_stem' +* `mmlu_humanities' +* `mmlu_social_sciences' +* `mmlu_other' + +Subgroup variants are prefixed with the subgroup name, e.g. `mmlu_stem_continuation`. + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [x] Is the "Main" variant of this task clearly denoted? +* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [x] Have you noted which, if any, published evaluation setups are matched by this variant? + +# changelog +ver 1: PR #497 +switch to original implementation + +ver 2: PR #2116 +add missing newline in description. diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/_generate_configs.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/_generate_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..88a7a2c2e63a5066b7f60a0bee8e8839173969e4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/_generate_configs.py @@ -0,0 +1,159 @@ +# noqa +""" +Take in a YAML, and output all "other" splits with this YAML +""" + +import argparse +import logging +import os + +import yaml +from tqdm import tqdm + + +eval_logger = logging.getLogger(__name__) + + +SUBJECTS = { + "abstract_algebra": "stem", + "anatomy": "stem", + "astronomy": "stem", + "business_ethics": "other", + "clinical_knowledge": "other", + "college_biology": "stem", + "college_chemistry": "stem", + "college_computer_science": "stem", + "college_mathematics": "stem", + "college_medicine": "other", + "college_physics": "stem", + "computer_security": "stem", + "conceptual_physics": "stem", + "econometrics": "social_sciences", + "electrical_engineering": "stem", + "elementary_mathematics": "stem", + "formal_logic": "humanities", + "global_facts": "other", + "high_school_biology": "stem", + "high_school_chemistry": "stem", + "high_school_computer_science": "stem", + "high_school_european_history": "humanities", + "high_school_geography": "social_sciences", + "high_school_government_and_politics": "social_sciences", + "high_school_macroeconomics": "social_sciences", + "high_school_mathematics": "stem", + "high_school_microeconomics": "social_sciences", + "high_school_physics": "stem", + "high_school_psychology": "social_sciences", + "high_school_statistics": "stem", + "high_school_us_history": "humanities", + "high_school_world_history": "humanities", + "human_aging": "other", + "human_sexuality": "social_sciences", + "international_law": "humanities", + "jurisprudence": "humanities", + "logical_fallacies": "humanities", + "machine_learning": "stem", + "management": "other", + "marketing": "other", + "medical_genetics": "other", + "miscellaneous": "other", + "moral_disputes": "humanities", + "moral_scenarios": "humanities", + "nutrition": "other", + "philosophy": "humanities", + "prehistory": "humanities", + "professional_accounting": "other", + "professional_law": "humanities", + "professional_medicine": "other", + "professional_psychology": "social_sciences", + "public_relations": "social_sciences", + "security_studies": "social_sciences", + "sociology": "social_sciences", + "us_foreign_policy": "social_sciences", + "virology": "other", + "world_religions": "humanities", +} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--base_yaml_path", required=True) + parser.add_argument("--save_prefix_path", default="mmlu") + parser.add_argument("--cot_prompt_path", default=None) + parser.add_argument("--task_prefix", default="") + parser.add_argument("--group_prefix", default="") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + # get filename of base_yaml so we can `"include": ` it in our "other" YAMLs. + base_yaml_name = os.path.split(args.base_yaml_path)[-1] + with open(args.base_yaml_path, encoding="utf-8") as f: + base_yaml = yaml.full_load(f) + + if args.cot_prompt_path is not None: + import json + + with open(args.cot_prompt_path, encoding="utf-8") as f: + cot_file = json.load(f) + + ALL_CATEGORIES = [] + for subject, category in tqdm(SUBJECTS.items()): + if category not in ALL_CATEGORIES: + ALL_CATEGORIES.append(category) + + if args.cot_prompt_path is not None: + description = cot_file[subject] + else: + description = f"The following are multiple choice questions (with answers) about {' '.join(subject.split('_'))}.\n\n" + + yaml_dict = { + "include": base_yaml_name, + "tag": f"mmlu_{args.task_prefix}_{category}" + if args.task_prefix != "" + else f"mmlu_{category}", + "task": f"mmlu_{args.task_prefix}_{subject}" + if args.task_prefix != "" + else f"mmlu_{subject}", + "task_alias": subject.replace("_", " "), + "dataset_name": subject, + "description": description, + } + + file_save_path = args.save_prefix_path + f"_{subject}.yaml" + eval_logger.info(f"Saving yaml for subset {subject} to {file_save_path}") + with open(file_save_path, "w", encoding="utf-8") as yaml_file: + yaml.dump( + yaml_dict, + yaml_file, + allow_unicode=True, + default_style='"', + ) + + if args.task_prefix != "": + mmlu_subcategories = [ + f"mmlu_{args.task_prefix}_{category}" for category in ALL_CATEGORIES + ] + else: + mmlu_subcategories = [f"mmlu_{category}" for category in ALL_CATEGORIES] + + if args.group_prefix != "": + file_save_path = args.group_prefix + ".yaml" + else: + file_save_path = args.save_prefix_path + ".yaml" + + eval_logger.info(f"Saving benchmark config to {file_save_path}") + with open(file_save_path, "w", encoding="utf-8") as yaml_file: + yaml.dump( + { + "group": f"mmlu_{args.task_prefix}" + if args.task_prefix != "" + else "mmlu", + "task": mmlu_subcategories, + }, + yaml_file, + indent=4, + default_flow_style=False, + ) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/_continuation_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/_continuation_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..264e27a5e8ebde9a203094c7cc9735ecf8ef3993 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/_continuation_template_yaml @@ -0,0 +1,13 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +output_type: multiple_choice +test_split: test +fewshot_split: dev +fewshot_config: + sampler: first_n +doc_to_text: "Question: {{question.strip()}}\nAnswer:" +doc_to_choice: "{{choices}}" +doc_to_target: "{{answer}}" +metadata: + version: 1.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c0cabf04b8ac1e1f9c809600214c589cfefbba79 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/_mmlu.yaml @@ -0,0 +1,32 @@ +group: mmlu_continuation +group_alias: mmlu (continuation) +task: + - group: stem + task: + - mmlu_continuation_stem + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: other + task: + - mmlu_continuation_other + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: social sciences + task: + - mmlu_continuation_social_sciences + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: humanities + task: + - mmlu_continuation_humanities + aggregate_metric_list: + - metric: acc + weight_by_size: True +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f4e29c0fb5147d883ee993d95822dde10b69d4e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_abstract_algebra.yaml @@ -0,0 +1,6 @@ +"dataset_name": "abstract_algebra" +"description": "The following are questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_abstract_algebra" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc3de9c4e6679ba4c9f66494c908d99781adf5bb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_anatomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "anatomy" +"description": "The following are questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_anatomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76aabcbfcf13a12e66e1af1daae2811b9b388fc8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_astronomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "astronomy" +"description": "The following are questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_astronomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e64d0920b9d1ac151712aac84a9e9c3f522c3c9f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_business_ethics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "business_ethics" +"description": "The following are questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_business_ethics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e79805df6f73782f25be4a302c738b73ecd2f2a2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_clinical_knowledge.yaml @@ -0,0 +1,6 @@ +"dataset_name": "clinical_knowledge" +"description": "The following are questions (with answers) about clinical\ + \ knowledge.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_clinical_knowledge" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..936f6ffe49245d558c0ef8fdf04b600dc177c375 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_biology" +"description": "The following are questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_college_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..289364ee44351c3d1bcee1193563babe6abe2a63 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_chemistry" +"description": "The following are questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_college_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7d3c5696067f09f9a68fdd9c3f7a1002d264128 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_computer_science" +"description": "The following are questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_college_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2dbc0932f63c0782e106db5fc27e96da9d816dec --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_mathematics" +"description": "The following are questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_college_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38abd2426f844916087795c4cc04355d8d6c2776 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_medicine" +"description": "The following are questions (with answers) about college\ + \ medicine.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_college_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee6b42584c834a5e92506650ee3aba58ed1cfd66 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_college_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_physics" +"description": "The following are questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_college_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ebb487dfbf634d390d2b2f9aa0e31e5a2f68fc6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_computer_security.yaml @@ -0,0 +1,6 @@ +"dataset_name": "computer_security" +"description": "The following are questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_computer_security" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c554caf07da77e4a9bb0bea9672dfcee4777b91 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_conceptual_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "conceptual_physics" +"description": "The following are questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_conceptual_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..848ce4e1f0dbff32d304c28f3d60d453e591a30f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_econometrics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "econometrics" +"description": "The following are questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_econometrics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d71dd16481a2bb5289ef5b713218dae0292bb11a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_electrical_engineering.yaml @@ -0,0 +1,6 @@ +"dataset_name": "electrical_engineering" +"description": "The following are questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_electrical_engineering" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe8aa09718cb8aef0dad48c21926f7dacc7b8ee9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_elementary_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_elementary_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb5dbd2e505e3fb4604dd75f2d5fe1a35fce3391 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_formal_logic.yaml @@ -0,0 +1,6 @@ +"dataset_name": "formal_logic" +"description": "The following are questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_formal_logic" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..280a50d2ee229b5f047a02024298474225203e54 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_global_facts.yaml @@ -0,0 +1,6 @@ +"dataset_name": "global_facts" +"description": "The following are questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_global_facts" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e518a5239a6da013ad31bfca284a3b7096bce840 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_biology" +"description": "The following are questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c38d60a7706306b215e156d4c27f05585945f7b4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5fe34f7af35456657c1acf40e05b3aaabc7893e8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..666c2742d1b762c103bbd02ff121676a047fb3e5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_european_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_european_history" +"description": "The following are questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_european_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41f6caf3e7f3b762af7c0350ca9a73d39bede2b8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_geography.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_geography" +"description": "The following are questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_geography" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e80233dc891e6890a5dec384ed2fbe5b82aca094 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce7fa9d5e3caa8dd3ec8e25172afda5f997b6c0c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2598dcb38eb9f8fdacced20c57d62c83dacb8a40 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96c414d3c411c6380cf83dca3b7aedc325598220 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_microeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45ab0a539a02ae322f66db689d8eddf13c8b856a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_physics" +"description": "The following are questions (with answers) about high\ + \ school physics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48dedf5c5ed94a836e0d802398ab05d7ab7db6ce --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_psychology" +"description": "The following are questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ee2418c7ff5235c1e31cf381502f5b21db60230 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_statistics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_statistics" +"description": "The following are questions (with answers) about high\ + \ school statistics.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_statistics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a00f16ceba2cfd3f313c8fe0d2df4a43e4bbe23d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_us_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_us_history" +"description": "The following are questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_us_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dc4cddf553bf0144b5d4ecc5eabe8efef0cf0367 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_high_school_world_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_world_history" +"description": "The following are questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_high_school_world_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..314edeb6c26c6a6be2d819b7c66e047cd48f8933 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_human_aging.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_aging" +"description": "The following are questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_human_aging" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a1473819ab4307f1e02024a0828ad9803710a59b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_human_sexuality.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_sexuality" +"description": "The following are questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_human_sexuality" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5ea8944bcc109000525b90f26f1d0da914d17437 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_international_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "international_law" +"description": "The following are questions (with answers) about international\ + \ law.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_international_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fca1dda86cc382604ca1bcbc308e0062e08dfa80 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_jurisprudence.yaml @@ -0,0 +1,6 @@ +"dataset_name": "jurisprudence" +"description": "The following are questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_jurisprudence" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b576f9fb3d0ce1d21e8d7543b56a539300be36a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_logical_fallacies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "logical_fallacies" +"description": "The following are questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_logical_fallacies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15fc3f4bdf0f34e96149ca2f8dddc90d037e8483 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_machine_learning.yaml @@ -0,0 +1,6 @@ +"dataset_name": "machine_learning" +"description": "The following are questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_continuation_stem" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_machine_learning" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..575604e0acf52132d9e489a070d28fd761e739eb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_management.yaml @@ -0,0 +1,6 @@ +"dataset_name": "management" +"description": "The following are questions (with answers) about management.\n\ + \n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_management" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af715bee02cfe813b5f045670c8e46dda258e77d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_marketing.yaml @@ -0,0 +1,6 @@ +"dataset_name": "marketing" +"description": "The following are questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_marketing" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3bf63614168f648497d046f015472497a2ac7553 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_medical_genetics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "medical_genetics" +"description": "The following are questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_medical_genetics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f457800932ec2fba831a1d81e6ca4495816f981f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_miscellaneous.yaml @@ -0,0 +1,6 @@ +"dataset_name": "miscellaneous" +"description": "The following are questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_miscellaneous" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0df1392d5baceb1a3dda1464acbb0b025a8428e8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_moral_disputes.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_disputes" +"description": "The following are questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_moral_disputes" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bea5e514b85a6ed83026a6fe9d399f92eb59ea99 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_moral_scenarios.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_scenarios" +"description": "The following are questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_moral_scenarios" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8db80340b2a9984cb8c3e41766e3f0e89af8f252 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_nutrition.yaml @@ -0,0 +1,6 @@ +"dataset_name": "nutrition" +"description": "The following are questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_nutrition" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..165de6c90ba1d4756c39e2f5605226dbeb86e314 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_philosophy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "philosophy" +"description": "The following are questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_philosophy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..02c4ee7f8af1856f498b7a55c83e085782e36666 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_prehistory.yaml @@ -0,0 +1,6 @@ +"dataset_name": "prehistory" +"description": "The following are questions (with answers) about prehistory.\n\ + \n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_prehistory" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb36a82b9c043b519379626f2d3618efdda9907b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_accounting.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_accounting" +"description": "The following are questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_professional_accounting" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac9f2592f41a2bcae43da174d2eb969cf1805251 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_law" +"description": "The following are questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_professional_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..328c128377609327abe0460e2d4ab6af716d02c3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_medicine" +"description": "The following are questions (with answers) about professional\ + \ medicine.\n\n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_professional_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0cca5bde048a23367aa2ccebc893e9fa71996d98 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_professional_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_psychology" +"description": "The following are questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_professional_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..700c407c2377d8d4d83bbf88d8f7a003a2e2900d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_public_relations.yaml @@ -0,0 +1,6 @@ +"dataset_name": "public_relations" +"description": "The following are questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_public_relations" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f5ef99e0f8fe8c98bc9994757d9cc6617e3550e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_security_studies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "security_studies" +"description": "The following are questions (with answers) about security\ + \ studies.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_security_studies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e78621aaa547b419f4133b94ce8dcba00c407f5c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_sociology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "sociology" +"description": "The following are questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_sociology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..989bb29aa095e83c2744011775864ef27258ca28 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_us_foreign_policy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "us_foreign_policy" +"description": "The following are questions (with answers) about us\ + \ foreign policy.\n\n" +"tag": "mmlu_continuation_social_sciences" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_us_foreign_policy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c938190bdd755f411914905d5309daa6938f313 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_virology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "virology" +"description": "The following are questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_continuation_other" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_virology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f707670066d3f2db4554221a12a3983e2d8febf5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/continuation/mmlu_world_religions.yaml @@ -0,0 +1,6 @@ +"dataset_name": "world_religions" +"description": "The following are questions (with answers) about world\ + \ religions.\n\n" +"tag": "mmlu_continuation_humanities" +"include": "_continuation_template_yaml" +"task": "mmlu_continuation_world_religions" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_default_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_default_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed0e70536b94b9d2127c2e02999d34cd6d0c3943 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_default_template_yaml @@ -0,0 +1,17 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +test_split: test +fewshot_split: dev +fewshot_config: + sampler: first_n +output_type: multiple_choice +doc_to_text: "{{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nAnswer:" +doc_to_choice: ["A", "B", "C", "D"] +doc_to_target: answer +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55099c6f16febd89270ad022abe181bf8ccd708e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu.yaml @@ -0,0 +1,11 @@ +group: mmlu +task: + - mmlu_stem + - mmlu_other + - mmlu_social_sciences + - mmlu_humanities +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_humanities.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_humanities.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7156e2230f09b461b8e783db323b9ee2d8023192 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_humanities.yaml @@ -0,0 +1,9 @@ +group: mmlu_humanities +group_alias: humanities +task: + - mmlu_humanities_tasks +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_other.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_other.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79025cec0c639a37872287ecb5ae5c444dce7478 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_other.yaml @@ -0,0 +1,9 @@ +group: mmlu_other +group_alias: other +task: + - mmlu_other_tasks +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_social_sciences.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_social_sciences.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fab1ec2c1416bc644c8723bdb18905dff9c00040 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_social_sciences.yaml @@ -0,0 +1,9 @@ +group: mmlu_social_sciences +group_alias: social sciences +task: + - mmlu_social_sciences_tasks +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_stem.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_stem.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cda82eff10a03afe1a05fd8a1368cf3a7c63dcd8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/_mmlu_stem.yaml @@ -0,0 +1,9 @@ +group: mmlu_stem +group_alias: stem +task: + - mmlu_stem_tasks +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcde12cb4c5566567482e095c87860f1c6179473 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_abstract_algebra.yaml @@ -0,0 +1,7 @@ +"dataset_name": "abstract_algebra" +"description": "The following are multiple choice questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_abstract_algebra" +"task_alias": "abstract_algebra" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5fef7490dd31872f2ed9dcde5c1e817e910b5e39 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_anatomy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "anatomy" +"description": "The following are multiple choice questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_anatomy" +"task_alias": "anatomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..660f07476dfdd115fc0b8d5f04c685b23857cc33 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_astronomy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "astronomy" +"description": "The following are multiple choice questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_astronomy" +"task_alias": "astronomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0f1b1c2dcd802effdf589d4f85b412593dfb622 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_business_ethics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "business_ethics" +"description": "The following are multiple choice questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_business_ethics" +"task_alias": "business_ethics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c72b71648df5a690963c95180a76f7ad0a495d4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_clinical_knowledge.yaml @@ -0,0 +1,7 @@ +"dataset_name": "clinical_knowledge" +"description": "The following are multiple choice questions (with answers) about clinical\ + \ knowledge.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_clinical_knowledge" +"task_alias": "clinical_knowledge" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ddfd713aa0581b36fdad44da4f80e5b500c47154 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_biology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_biology" +"description": "The following are multiple choice questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_college_biology" +"task_alias": "college_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..388c3a91bed8ffb7645e0e7f23fb0a81117503cc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_chemistry.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_chemistry" +"description": "The following are multiple choice questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_college_chemistry" +"task_alias": "college_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3f692423abfbf036fc0347fdfbb2642a6d16c39 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_computer_science.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_computer_science" +"description": "The following are multiple choice questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_college_computer_science" +"task_alias": "college_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08a9628af175edb897c7f6d88b96d4969fccad29 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_mathematics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_mathematics" +"description": "The following are multiple choice questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_college_mathematics" +"task_alias": "college_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35197a2a1885f7daf30209d4309dd059243260a8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_medicine.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_medicine" +"description": "The following are multiple choice questions (with answers) about college\ + \ medicine.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_college_medicine" +"task_alias": "college_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b5017afac65e0acf080a9df84098a1f21681833 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_college_physics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_physics" +"description": "The following are multiple choice questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_college_physics" +"task_alias": "college_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f9b42820f7f7196c6d02922337eaedb7ede5388 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_computer_security.yaml @@ -0,0 +1,7 @@ +"dataset_name": "computer_security" +"description": "The following are multiple choice questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_computer_security" +"task_alias": "computer_security" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af61a7e1579ac8613b5535e15a57adc629e2d571 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_conceptual_physics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "conceptual_physics" +"description": "The following are multiple choice questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_conceptual_physics" +"task_alias": "conceptual_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..609c20af2acbdd7ef36104dc97db97a40bfca6a5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_econometrics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "econometrics" +"description": "The following are multiple choice questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_econometrics" +"task_alias": "econometrics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8fa2137ad05a14e32d5d7e8973d6bc9c18d1a555 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_electrical_engineering.yaml @@ -0,0 +1,7 @@ +"dataset_name": "electrical_engineering" +"description": "The following are multiple choice questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_electrical_engineering" +"task_alias": "electrical_engineering" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d15f6d5ae88b6edf0bba2298ffaacbd4d103aedd --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_elementary_mathematics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are multiple choice questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_elementary_mathematics" +"task_alias": "elementary_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee2fc2f61073dc11f6f745eaf8927ab70aadad3f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_formal_logic.yaml @@ -0,0 +1,7 @@ +"dataset_name": "formal_logic" +"description": "The following are multiple choice questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_formal_logic" +"task_alias": "formal_logic" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b27ddefd25be9c6695900ce6d290a811b68356df --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_global_facts.yaml @@ -0,0 +1,7 @@ +"dataset_name": "global_facts" +"description": "The following are multiple choice questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_global_facts" +"task_alias": "global_facts" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22bc47943f0f66614f79cd0de5e7614afa1f08d5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_biology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_biology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_biology" +"task_alias": "high_school_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a25617cbd821411e6f0ca9fac853c76b7adb319 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_chemistry.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are multiple choice questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_chemistry" +"task_alias": "high_school_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad4c7d312c7e8f6517d308e6ffeb635a354b843e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_computer_science.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are multiple choice questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_computer_science" +"task_alias": "high_school_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c51bbdd7aa87b39da8145f8ea45f6fe13d17623 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_european_history.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_european_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_european_history" +"task_alias": "high_school_european_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aad87f1ad57a48102d7807a7a3fd75af86755912 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_geography.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_geography" +"description": "The following are multiple choice questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_geography" +"task_alias": "high_school_geography" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b93b363d658357619eaf907f8d04af339c22a12 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_government_and_politics" +"task_alias": "high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a08c579d1480ab592917d5a6673e63cf09198417 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_macroeconomics" +"task_alias": "high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f1b6d70e022414b7d370635daa49e3a9a8649c2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_mathematics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_mathematics" +"task_alias": "high_school_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac4f65dad5783bf23c50d9a39e912fe797a047e6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_microeconomics" +"task_alias": "high_school_microeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8c449aa1b5bb48d6899c328c82c44ee3ae3ef24 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_physics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_physics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school physics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_physics" +"task_alias": "high_school_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..47ba836c71b2be9759bd9fe48dd0cb687ef08636 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_psychology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_psychology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_psychology" +"task_alias": "high_school_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef5bdd7cf1577a7ba9f3365643c5e56b21c8a77e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_statistics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_statistics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school statistics.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_statistics" +"task_alias": "high_school_statistics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ececdb0ab921bdc24b8aac41979a93d35670d0c6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_us_history.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_us_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_us_history" +"task_alias": "high_school_us_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af34c8ddbe51abc0f44baff2bf8087b4c749825f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_high_school_world_history.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_world_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_high_school_world_history" +"task_alias": "high_school_world_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ca720be7c7d757c579e4563cb805dc36a6dcc6d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_human_aging.yaml @@ -0,0 +1,7 @@ +"dataset_name": "human_aging" +"description": "The following are multiple choice questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_human_aging" +"task_alias": "human_aging" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2acddd1e4ec1d85a7475202d43f5917abb085684 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_human_sexuality.yaml @@ -0,0 +1,7 @@ +"dataset_name": "human_sexuality" +"description": "The following are multiple choice questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_human_sexuality" +"task_alias": "human_sexuality" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9fb2a162aab92931f8b560ce0e76155fbc9bb675 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_international_law.yaml @@ -0,0 +1,7 @@ +"dataset_name": "international_law" +"description": "The following are multiple choice questions (with answers) about international\ + \ law.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_international_law" +"task_alias": "international_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c62a911ff5d849651d8c9e09feb34847846d147 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_jurisprudence.yaml @@ -0,0 +1,7 @@ +"dataset_name": "jurisprudence" +"description": "The following are multiple choice questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_jurisprudence" +"task_alias": "jurisprudence" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..adf8821e9a8ac9d80f1cfb5c6af5b74a63efda27 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_logical_fallacies.yaml @@ -0,0 +1,7 @@ +"dataset_name": "logical_fallacies" +"description": "The following are multiple choice questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_logical_fallacies" +"task_alias": "logical_fallacies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d846f96084a8cba059348a90d800a86b92ba09c2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_machine_learning.yaml @@ -0,0 +1,7 @@ +"dataset_name": "machine_learning" +"description": "The following are multiple choice questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_stem_tasks" +"include": "_default_template_yaml" +"task": "mmlu_machine_learning" +"task_alias": "machine_learning" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7dff834ef804039858b6955155a8338dd11b30b3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_management.yaml @@ -0,0 +1,7 @@ +"dataset_name": "management" +"description": "The following are multiple choice questions (with answers) about management.\n\ + \n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_management" +"task_alias": "management" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ef004988965c41ff075f2f976b98dca4657ca04 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_marketing.yaml @@ -0,0 +1,7 @@ +"dataset_name": "marketing" +"description": "The following are multiple choice questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_marketing" +"task_alias": "marketing" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..989fb2c1aea91035421e49c7a11293c48ffec0bc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_medical_genetics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "medical_genetics" +"description": "The following are multiple choice questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_medical_genetics" +"task_alias": "medical_genetics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7bb68bc2eb0f55b784943bd18296aabe3b86a31 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_miscellaneous.yaml @@ -0,0 +1,7 @@ +"dataset_name": "miscellaneous" +"description": "The following are multiple choice questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_miscellaneous" +"task_alias": "miscellaneous" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..348d21403f06669e198146286b83e227fbde5a16 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_moral_disputes.yaml @@ -0,0 +1,7 @@ +"dataset_name": "moral_disputes" +"description": "The following are multiple choice questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_moral_disputes" +"task_alias": "moral_disputes" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3762ee1200848439f08a3c69703af4cffb3a9d74 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_moral_scenarios.yaml @@ -0,0 +1,7 @@ +"dataset_name": "moral_scenarios" +"description": "The following are multiple choice questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_moral_scenarios" +"task_alias": "moral_scenarios" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..55f8ca01ff42a296c07d8fd2e2ccda373d91775b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_nutrition.yaml @@ -0,0 +1,7 @@ +"dataset_name": "nutrition" +"description": "The following are multiple choice questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_nutrition" +"task_alias": "nutrition" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5331c812ef70cb0123d754835fabde16ce330245 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_philosophy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "philosophy" +"description": "The following are multiple choice questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_philosophy" +"task_alias": "philosophy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0b4ff970a10b7be9ab08527124ea236227b60428 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_prehistory.yaml @@ -0,0 +1,7 @@ +"dataset_name": "prehistory" +"description": "The following are multiple choice questions (with answers) about prehistory.\n\ + \n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_prehistory" +"task_alias": "prehistory" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27b2ec9b9b70e00616d2560c3a8b1259781e8cfb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_accounting.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_accounting" +"description": "The following are multiple choice questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_professional_accounting" +"task_alias": "professional_accounting" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..07c36f1c38d46a513359d80284ead794dd72b7bd --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_law.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_law" +"description": "The following are multiple choice questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_professional_law" +"task_alias": "professional_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c5754bf379cfd884ad837243105a49e3e28d386 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_medicine.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_medicine" +"description": "The following are multiple choice questions (with answers) about professional\ + \ medicine.\n\n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_professional_medicine" +"task_alias": "professional_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0c0608ef6860edb4b8492402c674a7efda2070f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_professional_psychology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_psychology" +"description": "The following are multiple choice questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_professional_psychology" +"task_alias": "professional_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43b675bdfd088bb7e651eece031198b5c0fb8ab3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_public_relations.yaml @@ -0,0 +1,7 @@ +"dataset_name": "public_relations" +"description": "The following are multiple choice questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_public_relations" +"task_alias": "public_relations" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b02342d95ede5148ee8b0aeb9e4ad4fb7dd05938 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_security_studies.yaml @@ -0,0 +1,7 @@ +"dataset_name": "security_studies" +"description": "The following are multiple choice questions (with answers) about security\ + \ studies.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_security_studies" +"task_alias": "security_studies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49fa11620fb7147752328a484d56f8ead64c4387 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_sociology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "sociology" +"description": "The following are multiple choice questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_sociology" +"task_alias": "sociology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc4335e9eace7816ba112e4f55912223444d4c1f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_us_foreign_policy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "us_foreign_policy" +"description": "The following are multiple choice questions (with answers) about us\ + \ foreign policy.\n\n" +"tag": "mmlu_social_sciences_tasks" +"include": "_default_template_yaml" +"task": "mmlu_us_foreign_policy" +"task_alias": "us_foreign_policy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f8bc114c3ce7437ad0fb413a69a859f69bcbf99 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_virology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "virology" +"description": "The following are multiple choice questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_other_tasks" +"include": "_default_template_yaml" +"task": "mmlu_virology" +"task_alias": "virology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b90972c7031c30d89beea835f70aab7cf45cce81 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/default/mmlu_world_religions.yaml @@ -0,0 +1,7 @@ +"dataset_name": "world_religions" +"description": "The following are multiple choice questions (with answers) about world\ + \ religions.\n\n" +"tag": "mmlu_humanities_tasks" +"include": "_default_template_yaml" +"task": "mmlu_world_religions" +"task_alias": "world_religions" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_cot_prompts.json b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_cot_prompts.json new file mode 100644 index 0000000000000000000000000000000000000000..c374b19d03391e61021af6640558a6de8853d7b0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_cot_prompts.json @@ -0,0 +1 @@ +{"abstract_algebra": "The following are multiple choice questions (with answers) about abstract algebra.\n\nQ: Statement 1 | Every element of a group generates a cyclic subgroup of the group. Statement 2 | The symmetric group S_10 has 10 elements.\n(A) True, True (B) False, False (C) True, False (D) False, True\nA: Let's think step by step. A cyclic group is a group that is generated by a single element. Hence a subgroup generated by a single element of a group is cyclic and Statement 1 is True. The answer is (C).\n\nQ: The symmetric group $S_n$ has $\nactorial{n}$ elements, hence it is not true that $S_{10}$ has 10 elements.\nFind the characteristic of the ring 2Z.\n(A) 0 (B) 3 (C) 12 (D) 30\nA: Let's think step by step. A characteristic of a ring is R is $n$ if the statement $ka = 0$ for all $a\\in 2Z$ implies that $k$ is a multiple of $n$. Assume that $ka = 0$ for all $a\\in 2Z$ for some $k$. In particular $2k = 0$. Hence $k=0$ and $n=0$. The answer is (A).\n\nQ: Statement 1| Every function from a finite set onto itself must be one to one. Statement 2 | Every subgroup of an abelian group is abelian.\n(A) True, True (B) False, False (C) True, False (D) False, True\nA: Let's think step by step. Statement 1 is true. Let $S$ be a finite set. If $f:S \nightarrow S$ is a onto function, then $|S| = |f(S)|$. If $f$ was not one to one, then for finite domain $S$ the image would have less than $S$ elements, a contradiction.\nStatement 2 is true. Let $G$ be an abelian group and $H$ be a subgroup of $G$. We need to show that $H$ is abelian. Let $a,b \\in H$. Then $a,b \\in G$ and $ab=ba$. Since $G$ is abelian, $ab=ba$. Since $H$ is a subgroup of $G$, $ab \\in H$. Therefore, $ab=ba$ and $H$ is abelian. The answer is (A).\n\nQ: Statement 1 | If aH is an element of a factor group, then |aH| divides |a|. Statement 2 | If H and K are subgroups of G then HK is a subgroup of G.\n(A) True, True (B) False, False (C) True, False (D) False, True\nA: Let's think step by step. Statement 2 is false. Let $H$ be a subgroup of $S_3$ generated by the cycle $(1,2)$ and $K$ be a subgroup of $S_3$ generated by the cycle $(1,3)$. Both $H$ and $K$ have two elements, the generators and the identity. However $HK$ contains cycles (1,2), (1,3) and (2,3,1), but the inverse of (2,3,1) is (2,1,3) and it does not belong to HK, hence HK is not a subgroup. The answer is (B).\n\nQ: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\n(A) 0 (B) 1 (C) 2 (D) 3\nA: Let's think step by step. Z_3[x]/(x^2 + c) is a field if and only if x^2 + c does not have roots in Z_3. That is x^2 + c != 0 for every x in Z_3. If c = 0, then x^2 + c = x^2 has root 0. If c = 1 then x^2 + c = x^2 + 1 = 0 + 1 for x = 0, 1 + 1 = 2 for x = 1 and 1 + 1 = 2 for x = 2, hence x^2 + 1 does not have any roots. For c = 2 the polynomial x^2 + 2 has two roots at x = 1 and x = 2. Hence Z_3[x]/(x^2 + c) is a field if and only if c = 1. The answer is (B).\n\n", "anatomy": "The following are multiple choice questions (with answers) about anatomy.\n\nQ: Which of the following is the body cavity that contains the pituitary gland?\n(A) Abdominal (B) Cranial (C) Pleural (D) Spinal\nA: Let's think step by step. We refer to Wikipedia articles on anatomy for help. Let\u2019s solve this problem step by step. The pituitary gland is the major endocrine gland attached to the base of the brain, and it is contained in the Cranial cavity. The answer is (B).\n\nQ: Which of these branches of the trigeminal nerve contain somatic motor processes?\n(A) The supraorbital nerve (B) The infraorbital nerve (C) The mental nerve (D) None of the above\nA: Let's think step by step. We refer to Wikipedia articles on anatomy for help. Let\u2019s solve this problem step by step. \nWe know the following: (A) The supraorbital nerve (also known as the frontal nerve) is the largest branch of the ophthalmic nerve and branch of ophthalmic division of the trigeminal nerve. (B) The infraorbital nerve is a branch of the maxillary division of the trigeminal nerve. (C) The mental nerve is a branch of the mandibular division of the trigeminal nerve. Because all these nerves are purely sensory nerves and do not contain any somatic motor processes. Therefore, the answer should be none of the above, which is (D). The answer is (D).\n\nQ: In Angle's Class II Div 2 occlusion there is\n(A) excess overbite of the upper lateral incisors. (B) negative overjet of the upper central incisors. (C) excess overjet of the upper lateral incisors. (D) excess overjet of the upper central incisors.\nA: Let's think step by step. We refer to Wikipedia articles on anatomy for help. Let\u2019s solve this problem step by step. This is a question related to anatomy and orthodontics. Excess overjet is associated with Class II occlusions; therefore, we can safely eliminate (B) from the list, as negative overjet is often associated with Class III occlusions. Now, we need to determine the location of the excess overjet, and that would be the upper (maxillary) lateral incisors. Only (C) has the correct information. The answer is (C).\n\nQ: The pleura\n(A) have no sensory innervation. (B) are separated by a 2 mm space. (C) extend into the neck. (D) are composed of respiratory epithelium.\nA: Let's think step by step. We refer to Wikipedia articles on anatomy for help. Let\u2019s solve this problem step by step. First, recall that the pleura refers to the thin layer of tissue that covers the lungs and lines the interior wall of the chest cavity. Now, let\u2019s look at each option:\nOption (A): \u201cThe pleura have no sensory innervation.\u201d This information is not correct. The pleura do have a sensory innervation.\nOption (B): \u201cThe pleura are separated by a 2 mm space.\u201d This information is not correct. There is a very thin \u201cpotential\u201d space between the layers of the pleura; however, it is typically filled with serous pleural fluid. \nOption (C): \u201cThe pleura extend into the neck.\u201d This information is actuakky true. The cervical pleura, also known as the dome of the pleuradome of the pleura, lines the extendsiton of the pleural cavity into the neck.\nOption (D): \u201cThe pleura are composed of respiratory epithelium.\u201d This information is not correct. The pleaura are composed of connective tissue (CT).\nBecause (A), (B), and (D) are all incorrect, (D) is the only correct answer. The answer is (C).\n\nQ: What is the embryological origin of the hyoid bone?\n(A) The first pharyngeal arch (B) The first and second pharyngeal arches (C) The second pharyngeal arch (D) The second and third pharyngeal arches\nA: Let's think step by step. We refer to Wikipedia articles on anatomy for help. Let\u2019s solve this problem step by step. The hyoid bone, which is also known as the hyooid, is a a small U-shaped bone located in the anterior neck. In its resting position, it lies between the ase of the mandible and the third cervical vertebrae. We know that the second and the third pharyngeal arches give rise to the horns of the hyoid bone; therefore, the embryological origin of the hyoid bone are the second and the third pharyngeal arches\u2014this information is covered in the last option (D). Therefore, we conclude that (D) must be the correct answer. The answer is (D).\n\n", "astronomy": "The following are multiple choice questions (with answers) about astronomy.\n\nQ: Where do most short-period comets come from and how do we know?\n(A) The Kuiper belt; short period comets tend to be in the plane of the solar system just like the Kuiper belt. (B) The Kuiper belt; short period comets tend to come from random directions indicating a spherical distribution of comets called the Kuiper belt. (C) The asteroid belt; short period comets have orbital periods similar to asteroids like Vesta and are found in the plane of the solar system just like the asteroid belt. (D) The Oort cloud; short period comets tend to be in the plane of the solar system just like the Oort cloud.\nA: Let's think step by step. Most short-period comets come from the Kuiper belt, and we know because short period coments tend to be in the plane of the solar system, just like the Kuiper belt is. The answer is (A).\n\nQ: You are pushing a truck along a road. Would it be easier to accelerate this truck on Mars? Why? (Assume there is no friction)\n(A) It would be harder since the truck is heavier on Mars. (B) It would be easier since the truck is lighter on Mars. (C) It would be harder since the truck is lighter on Mars. (D) It would be the same no matter where you are.\nA: Let's think step by step. If we assume that there is no friction, the force needed to accelerate the truck is by Newton\u2019s second law only dependent on the mass of the truck. Hence (A), (B) and (C) are incorrect since it doesn\u2019t matter that it\u2019s on Mars, and (D) is the correct answer. The answer is (D).\n\nQ: Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?\n(A) 10000 times more (B) 100 times more (C) 1000 times more (D) 10 times more\nA: Let's think step by step. The amount of light is proportional to the aperture area $A = \\pi D^2/4$ for a lens with diameter $D$, so the relative amounts of light between the eye with diameter 5mm and the telescope with diameter 50mm is $(50 cm)^2/(5mm)^2 = 10000$. The answer is (A).\n\nQ: Why isn't there a planet where the asteroid belt is located?\n(A) A planet once formed here but it was broken apart by a catastrophic collision. (B) There was not enough material in this part of the solar nebula to form a planet. (C) There was too much rocky material to form a terrestrial planet but not enough gaseous material to form a jovian planet. (D) Resonance with Jupiter prevented material from collecting together to form a planet.\nA: Let's think step by step. The asteroid belt is a stellar disc consisting of a large number of asteroids between Mars and Jupiter's orbits. The asteroids in this belt are affected by the gravitational pull from both other asteroids and nearby planets. Due to the strong gravitational force of Jupiter there are resonances that give rise to low density regions of asteroids known as the Kirkwood gap. So (B) and (C) are not correct since it\u2019s not a lack of material that prevents a planet from being formed, and (A) is incorrect because the Kirkwood gap would have prevented a planet from forming in the first place, and (D) is the correct option. The answer is (D).\n\nQ: Why is Mars red?\n(A) Because the surface is covered with heavily oxidized (\"rusted\") minerals. (B) Because the atmosphere scatters more light at bluer wavelengths transmitting mostly red light. (C) Because Mars is covered with ancient lava flows which are red in color. (D) Because flowing water on Mars's surface altered the surface minerals several billion years ago.\nA: Let's think step by step. Option (B) is not correct because if the red color was caused by the scattering off the atmosphere, then the earth with a much thicker atmosphere would also look red. Options (C) and (D) are not specific enough about why the color of the surface would be red, while (A) is correct because it explains that the surface is red due to the rusted materials on the surface and the red color comes from the rust. So the correct option is (A). The answer is (A).\n\n", "business_ethics": "The following are multiple choice questions (with answers) about business ethics.\n\nQ: In contrast to _______, _______ aim to reward favourable behaviour by companies. The success of such campaigns have been heightened through the use of ___________, which allow campaigns to facilitate the company in achieving _________ .\n(A) Buycotts, Boycotts, Blockchain technology, Charitable donations (B) Buycotts, Boycotts, Digital technology, Increased Sales (C) Boycotts, Buyalls, Blockchain technology, Charitable donations (D) Boycotts, Buycotts, Digital technology, Increased Sales\nA: Let's think step by step. We refer to Wikipedia articles on business ethics for help. The sentence that best uses the possible options above is \u201cIn contrast to *boycotts*, *buycotts* aim to reward favourable behavior by companies. The success of such campaigns have been heightened through the use of *digital technology*, which allow campaigns to facilitate the company in achieving *increased sales*.\u201d The answer is (D).\n\nQ: _______ is the direct attempt to formally or informally manage ethical issues or problems, through specific policies, practices and programmes.\n(A) Corporate social responsibility (B) Business ethics management (C) Sustainability (D) Environmental management\nA: Let's think step by step. We refer to Wikipedia articles on business ethics for help. The direct attempt manage ethical issues through specific policies, practices, and programs is business ethics management. The answer is (B).\n\nQ: Three contrasting tactics that CSO's can engage in to meet their aims are ________ which typically involves research and communication, ________, which may involve physically attacking a company's operations or ________, often involving some form of _______.\n(A) Non-violent direct action, Violent direct action, Indirect action, Boycott (B) Indirect action, Instrumental action, Non-violent direct action, Information campaign (C) Indirect action, Violent direct action, Non-violent direct-action Boycott (D) Non-violent direct action, Instrumental action, Indirect action, Information campaign\nA: Let's think step by step. We refer to Wikipedia articles on business ethics for help. The sentence that best uses the possible options above is \u201cThree contrasting tactics that CSO's can engage in to meet their aims are *indirect action*, which typically involves research and communication, *violent direct action*, which may involve physically attacking a company's operations or *non-violent direct action*, often involving some form of *boycott*.\u201d The answer is (C).\n\nQ: To ensure the independence of the non-executive board members, there are a number of steps which can be taken, which include non-executives being drawn from _______ the company, being appointed for a _________ time period as well as being appointed _________.\n(A) Outside, Limited, Independently (B) Inside, Limited, Intermittently (C) Outside, Unlimited, Intermittently (D) Inside, Unlimited, Independently\nA: Let's think step by step. We refer to Wikipedia articles on business ethics for help. The sentence that best uses the possible options above is \u201cTo ensure the independence of the non-executive board members, there are a number of steps which can be taken, which include non-executives being draw from *outside* the company, being appointed for a *limited* time period as well as being imported *independently*. The answer is (A).\n\nQ: Beyond the business case for engaging in CSR there are a number of moral arguments relating to: negative _______, the _______that corporations possess and the ________ of business and society.\n(A) Externalities, Power, Independence (B) Publicity, Insubstantial resources, Mutual dependence (C) Publicity, Power, Independence (D) Externalities, Power, Mutual dependence\nA: Let's think step by step. We refer to Wikipedia articles on business ethics for help. The sentence that best uses the possible options above is \u201cBeyond the business case for engaging the CSR there are a number of moral arguments relating to: negative *externalities*, the *power* that corporations possess and the *mutual independence* of business and society. The answer is (D).\n\n", "clinical_knowledge": "The following are multiple choice questions (with answers) about clinical knowledge.\n\nQ: Glycolysis is the name given to the pathway involving the conversion of:\n(A) glycogen to glucose-1-phosphate. (B) glycogen or glucose to fructose. (C) glycogen or glucose to pyruvate or lactate. (D) glycogen or glucose to pyruvate or acetyl CoA.\nA: Let's think step by step. We refer to Wikipedia articles on clinical knowledge for help. Glycolysis is the name given to the pathway involving conversion of glycogen or glucose to pyruvate or lactate. The answer is (C).\n\nQ: What is the difference between a male and a female catheter?\n(A) Male and female catheters are different colours. (B) Male catheters are longer than female catheters. (C) Male catheters are bigger than female catheters. (D) Female catheters are longer than male catheters.\nA: Let's think step by step. We refer to Wikipedia articles on clinical knowledge for help. The difference between a male and female catheter is that male catheters tend to be longer than female catheters. The answer is (B).\n\nQ: How many attempts should you make to cannulate a patient before passing the job on to a senior colleague, according to the medical knowledge of 2020?\n(A) 4 (B) 3 (C) 2 (D) 1\nA: Let's think step by step. We refer to Wikipedia articles on clinical knowledge for help. According to the medical protocol as of 2020, you should make two attempts to cannulate a patient before passing the job on to a more-senior practitioner. The answer is (C).\n\nQ: In the assessment of the hand function which of the following is true?\n(A) Abduction of the thumb is supplied by spinal root T2 (B) Opposition of the thumb by opponens policis is supplied by spinal root T1 (C) Finger adduction is supplied by the median nerve (D) Finger abduction is mediated by the palmar interossei\nA: Let's think step by step. We refer to Wikipedia articles on clinical knowledge for help. Of all the options, it is only true that the opposition of the thumb by opponens pollicis is supplied by spinal root T1. The answer is (B).\n\nQ: The energy for all forms of muscle contraction is provided by:\n(A) ATP. (B) ADP. (C) phosphocreatine. (D) oxidative phosphorylation.\nA: Let's think step by step. We refer to Wikipedia articles on clinical knowledge for help. The energy for muscular contraction is provided by ATP (adenosine triphosphate), which is the powerhouse of the cell. The answer is (A).\n\n", "college_biology": "The following are multiple choice questions (with answers) about college biology.\n\nQ: Which of the following represents an accurate statement concerning arthropods?\n(A) They possess an exoskeleton composed primarily of peptidoglycan. (B) They possess an open circulatory system with a dorsal heart. (C) They are members of a biologically unsuccessful phylum incapable of exploiting diverse habitats and nutrition sources. (D) They lack paired, jointed appendages.\nA: Let's think step by step. Peptidoglycan is known to comprise the plasma membrane of most bacteria, rather than the exoskeleton of arthropods, which is made of chitin, which rules out (A). The answer (C) is false because arthropods are a highly successful phylum. Likewise, arthropods have paired, jointed appendages, which rules out (D). The only remaining option is (B), as arthropods have an open circulatory system with a dorsal tubular heart. The answer is (B).\n\nQ: In a given population, 1 out of every 400 people has a cancer caused by a completely recessive allele, b. Assuming the population is in Hardy-Weinberg equilibrium, which of the following is the expected proportion of individuals who carry the b allele but are not expected to develop the cancer?\n(A) 1/400 (B) 19/400 (C) 20/400 (D) 38/400\nA: Let's think step by step. According to the Hardy Weinberg Law, $p^2 + 2 p q + q^2 = 1$, and $p + q = 1$ where $p$ is the frequency of the dominant allele, $q$ is the frequency of the recessive allele, and $p^2$, $q^2$, and $2pq$ are the frequencies of dominant homozygous, recessive homozygous, and heterozygous individuals, respectively. \u200bThe frequency of the recessive allele (q) is $\\sqrt{\frac{1}{400}} = 0.05$. We have $p = 1 - q = 0.95$. The frequency of heterozygous individuals is $2pq = 2 \\cdot 0.05 \\cdot 0.95 = 0.095$. The number of heterozygous individuals is equal to the frequency of heterozygous individuals times the size of the population, or $0.095 * 400 = 38$. So we end up with 38/400. The answer is (D).\n\nQ: According to the pressure-flow model of movement of phloem contents, photosynthate movement from source to sink is driven by\n(A) an ATP-dependent pressure-flow pump (B) a water-pressure potential gradient (C) transpiration (D) apoplastic diffusion\nA: Let's think step by step. It is a gradient in water pressure that induces the movement of phloem content, which refers to answer (B). The mechanism of movement does not rely on metabolism, which rules out (A). Transpiration refers to the exhalation of water vapor through plant stomata, and is also not related, which rules out (C). While the apoplastic pathway is one of two main pathways for water transport in plants, it is not central to the pressure flow model, which rules out (D). The answer is (B).\n\nQ: Which of the following contain DNA sequences required for the segregation of chromosomes in mitosis and meiosis?\n(A) Telomeres (B) Centromeres (C) Nucleosomes (D) Spliceosomes\nA: Let's think step by step. The genetic material in Telomeres is not used, which rules out (A). Nucleosomes are the repeating subunit that comprises chromatin packed in a cell nucleus, and do not specifically refer to DNA sequences necessary for segregating chromosomes in cell division, which rules out (C). A spliceosome is a large ribonucleoprotein that removes introns from transcribed pre-mRNA rather than governing chromosome segregation. Centromeres are directly responsible for segregating chromosomes in cell division. The answer is (B).\n\nQ: The presence of homologous structures in two different organisms, such as the humerus in the front limb of a human and a bird, indicates that\n(A) the human and bird are polyphyletic species (B) a human's and bird's evolution is convergent (C) the human and bird belong to a clade (D) the human and bird developed by analogy\nA: Let's think step by step. Polyphyletic species are organisms that are grouped due to having similar characteristics but which do not have a common ancestor. This is not the case for humans and birds, which rules out (A). Convergent evolution refers to the indepdendent development of similar features in different species at different periods, which is also not the case for humans and birds, which rules out (B). Analogy refers to the superficial resemblance of structures that have different origins, which is not the case for the human and bird forearms, which rules out (D). Humans and birds do belong to the same clade - a group of organisms composed of a common ancestor. The answer is (C).\n\n", "college_chemistry": "The following are multiple choice questions (with answers) about college chemistry.\n\nQ: 3 Cl\u2212(aq) + 4 CrO_4^2\u2212(aq) + 23 H+(aq) \u2192 3 HClO2(aq) + 4 Cr3+(aq) + 10 H2O(l). In the reaction shown above, Cl\u2212(aq) behaves as\n(A) an acid (B) a base (C) a catalyst (D) a reducing agent\nA: Let's think step by step. A molecule that behaves as a base accepts an H+ ion (or proton) from another molecule, whereas a molecule that behaves as an acid donates an H+ ion (or proton) to another molecule. Neither of these is the case for Cl in this reaction, which rules out (A) and (B). A catalyst is a substance that only accelerates a reaction without itself undergoing chemical change, which is not the case here. This rules out (C). Instead, the $Cl^{-} molecules carry a negative charge, which they donate in the reaction to form 3 HClO2. This is the behavior of a reducing agent, or (D). The answer is (D).\n\nQ: Which of the following statements about the lanthanide elements is NOT true?\n(A) The most common oxidation state for the lanthanide elements is +3. (B) Lanthanide complexes often have high coordination numbers (> 6). (C) All of the lanthanide elements react with aqueous acid to liberate hydrogen. (D) The atomic radii of the lanthanide elements increase across the period from La to Lu.\nA: Let's think step by step. The atomic radii of the lanthanide elements in fact decrease across the period from La to Lu. Options (A), (B), and (C) are all true. This means that only (D) is NOT true. The answer is (D).\n\nQ: Which of the following lists the hydrides of group-14 elements in order of thermal stability, from lowest to highest?\n(A) PbH4 < SnH4 < GeH4 < SiH4 < CH4 (B) PbH4 < SnH4 < CH4 < GeH4 < SiH4 (C) CH4 < SiH4 < GeH4 < SnH4 < PbH4 (D) CH4 < PbH4 < GeH4 < SnH4 < SiH4\nA: Let's think step by step. The thermal stability of group-14 hydrides decreases as we move from the top of group 14 to the bottom. The order of elements in the group from top to bottom is C, Si, Ge, Sn, Pb. Therefore in order of increasing thermal stability we have PbH4, SnH4, GeH4, SiH4, and CH4, or answer (A). The answer is (A).\n\nQ: Predict the number of lines in the EPR spectrum of a solution of 13C-labelled methyl radical (13CH3\u2022), assuming the lines do not overlap.\n(A) 4 (B) 3 (C) 6 (D) 24 (E) 8\nA: Let's think step by step. The electron paramagnetic resonance spectrum will be split by two forms of interactions. The first is the hyperfine interaction with the 13C (nuclear spin $I = \nrac{1}{2}$) which will split the spectrum into 2 lines. This will be further split into 4 lines by the interaction with three equivalent 1H nuclei. The total number of lines is therefore $2 \\cdot 4 = 8$. The answer is (E).\n\n", "college_computer_science": "The following are multiple choice questions (with answers) about college computer science.\n\nQ: Which of the following regular expressions is equivalent to (describes the same set of strings as) (a* + b)*(c + d)?\n(A) a*(c + d)+ b(c + d)\n(B) a*(c + d)* + b(c + d)*\n(C) a*(c + d)+ b*(c + d)\n(D) (a + b)*c +(a + b)*d\nA: Let's think step by step. We know that:\n1. (X* + Y)* = (X + Y)*\n2. X(Y + Z)? = XY + XZ\nUsing equation 1 we can rewrite (a* + b)*(c + d)? as:\n3. (a + b)*(c + d)?\nUsing equation 2 we can rewrite equation 3 as:\n(a + b)*c + (a + b)*d The answer is (D).\n\nQ: The Singleton design pattern is used to guarantee that only a single instance of a class may be instantiated. Which of the following is (are) true of this design pattern?\nI. The Singleton class has a static factory method to provide its instance.\nII. The Singleton class can be a subclass of another class.\nIII. The Singleton class has a private constructor.\n(A) I only\n(B) II only\n(C) III only\n(D) I, II, and III\nA: Let's think step by step. Statement I is a correct statement about a Singleton, because a Singleton restricts instantiation to a single, static method. Statement II is also correct, because there is no inherent restriction regarding the inheritance of a Singleton. Statement III is also correct, because a Singletons must be instantiated only once, so its constructor is made private to prevent any construction except via its static factory method.\nGiven these facts, statements I, II, and III are all correct. The answer is (D).\n\nQ: A certain pipelined RISC machine has 8 general-purpose registers R0, R1, . . . , R7 and supports the following operations:\nADD Rs1, Rs2, Rd (Add Rs1 to Rs2 and put the sum in Rd)\nMUL Rs1, Rs2, Rd (Multiply Rs1 by Rs2 and put the product in Rd)\nAn operation normally takes one cycle; however, an operation takes two cycles if it produces a result required by the immediately following operation in an operation sequence.\nConsider the expression AB + ABC + BC, where variables A, B, C are located in registers R0, R1, R2. If the contents of these three registers must not be modified, what is the minimum number of clock cycles required for an operation sequence that computes the value of AB + ABC + BC?\n(A) 5 (B) 6 (C) 7 (D) 8\nA: Let's think step by step. First, we are given that A is in R0, B is in R1, and C is in R2.\nNext, we can see that we must compute three multiplies (AB, BC, and ABC) and two adds (AB + ABC, (AB + ABC) + BC) to compute our final answer, resulting in a minimum of five clock cycles.\nNext, we can see that there is no way to avoid at least one pipeline stall when computing our final answer, because to compute our final sum we must wait at least one cycle for the results from the previous stage to be ready. Thus, our minimum number of cycles must be 6.\nWe can verify that we can create a solution that requires only six cycles as follows:\ncompute AB: MUL R0, R1, R3\ncompute BC: MUL R1, R2, R4\ncompute ABC: MUL R3, R4, R5\ncompute AB + BC: ADD R3, R4, R6\nSTALL\ncompute AB + ABC + BC: ADD R5, R6, R7\nSo there are 6 cycles. The answer is (B).\n\nQ: A compiler generates code for the following assignment statement.\nG := (A + B) * C - (D + E) * F\nThe target machine has a single accumulator and a single-address instruction set consisting of instructions load, store, add, subtract, and multiply. For the arithmetic operations, the left operand is taken from the accumulator and the result appears in the accumulator. The smallest possible number of instructions in the resulting code is\n(A) 5 (B) 6 (C) 7 (D) 9\nA: Let's think step by step. We can compute the final answer with the following sequence of operations:\n1. LOAD D (accumulator = D)\n2. ADD E (accumulator = D+E)\n3. MUL F (accumulator = (D+E)*F)\n4. STORE X (X = (D+E)*F)\n5. LOAD A (accumulator = A)\n6. ADD B (accumulator = A+B)\n7. MUL C (accumulator = (A+B)*C)\n8. SUB X (accumulator = (A+B)*C - (D+E)*F)\n9. STORE G (G = (A+B)*C - (D+E)*F)\nThis sequence takes 9 instructions. The answer is (D).\n\nQ: Consider a computer design in which multiple processors, each with a private cache memory, share global memory using a single bus. This bus is the critical system resource. Each processor can execute one instruction every 500 nanoseconds as long as memory references are satisfied by its local cache. When a cache miss occurs, the processor is delayed for an additional 2,000 nanoseconds. During half of this additional delay, the bus is dedicated to serving the cache miss. During the other half, the processor cannot continue, but the bus is free to service requests from other processors. On average, each instruction requires 2 memory references. On average, cache misses occur on 1 percent of references. What proportion of the capacity of the bus would a single processor consume, ignoring delays due to competition from other processors?\n(A) 1/50 (B) 1/27 (C) 1/25 (D) 2/27\nA: Let's think step by step. We know that each instruction requires two memory references per instruction, and that there is an average cache miss rate of one percent.\nThus a given processor has:\n(1 cache miss / 100 references) * (2 references / instruction) =\n(2 cache misses / 100 instructions), so:\nmisses_per_instruction = 1 cache miss / 50 instructions.\nNext, we know that each instruction requires 500 nanoseconds when there is no cache miss, and 500 + 2000 = 2500 nanoseconds when there is a cache miss. Thus:\n50 instructions / (49 * 500) + (1 * 2500) nanoseconds, so:\ninstructions_per_ns = 50 instructions / 27000 nanoseconds.\nNow, we know that each cache miss locks the bus for half of the 2000 nanosecond cache miss delay, or 1000 nanoseconds, so:\nlock_ns_per_miss = 1000 nanoseconds / cache miss.\nThus we can see that on average a single processor will lock the bus for:\nlock_ns_per_miss * misses_per_instruction * instructions_per_ns =\n(1000 nanoseconds / cache miss) * (1 cache miss / 50 instructions) * (50 instructions / 27000 nanoseconds) = 1000 * (1/50) * (50/27000) = 1000/27000 = 1/27. The answer is (B).\n\n", "college_mathematics": "The following are multiple choice questions (with answers) about college mathematics.\n\nQ: Let V be the set of all real polynomials p(x). Let transformations T, S be defined on V by T:p(x) -> xp(x) and S:p(x) -> p'(x) = d/dx p(x), and interpret (ST)(p(x)) as S(T(p(x))). Which of the following is true?\n(A) ST = 0 (B) ST = T (C) ST = TS (D) ST - TS is the identity map of V onto itself.\nA: Let's think step by step. For a given polynomial $p$ we have\n\\[ST(p) = (xp(x))\u2019 = p(x) + xp\u2019(x)\\]\nand\n\\[TS(p) = xp\u2019(x).\\]\nHence \\[ST(p) - TS(p) = p(x) + xp\u2019(x) - xp\u2019(x).\\] The answer is (D).\n\nQ: Suppose that f(1 + x) = f(x) for all real x. If f is a polynomial and f(5) = 11, then f(15/2)\n(A) -11 (B) 0 (C) 11 (D) 33/2\nA: Let's think step by step. The only polynomial so that $f(1 + x) = f(x)$ is a constant polynomial. Hence $f(5) = 11 = f(15/2)$. The answer is (C).\n\nQ: Let A be a real 2x2 matrix. Which of the following statements must be true?\nI. All of the entries of A^2 are nonnegative.\nII. The determinant of A^2 is nonnegative.\nIII. If A has two distinct eigenvalues, then A^2 has two distinct eigenvalues.\n(A) I only (B) II only (C) III only (D) II and III only\nA: Let's think step by step. We have \\[ det(A^2) = (det(A))^2 \\geq 0,\\] hence II holds.\nIII is false: as a counterexample take a diagonal matrix with -1 and 1 on the diagonal. Then $A^2$ is the identity matrix. The answer is (B).\n\nQ: Let A be the set of all ordered pairs of integers (m, n) such that 7m + 12n = 22. What is the greatest negative number in the set B = {m + n : (m, n) \\in A}?\n(A) -5 (B) -4 (C) -3 (D) -2\nA: Let's think step by step. We have 12n = 22 - 7m and one of the solutions is $m = -2$, $n = 3$. Then $m + n = 1$, hence we need to look for smaller $m$ in order to make $m + n$ negative. The next solution is $m = -14$ and $n = 10$. For smaller $m$ we have $m + n$ smaller than $-4$. The answer is (B).\n\nQ: A tank initially contains a salt solution of 3 grams of salt dissolved in 100 liters of water. A salt solution containing 0.02 grams of salt per liter of water is sprayed into the tank at a rate of 4 liters per minute. The sprayed solution is continually mixed with the salt solution in the tank, and the mixture flows out of the tank at a rate of 4 liters per minute. If the mixing is instantaneous, how many grams of salt are in the tank after 100 minutes have elapsed?\n(A) 2 (B) 2 - e^-2 (C) 2 + e^-2 (D) 2 + e^-4\nA: Let's think step by step. For all $t \\in \\mathbb{R}$, let $s(t)$ denote the number grams of salt in the tank at the $t$ minute mark. Then $s(0) = 3$.\nWe use $s$ and $s(t)$ interchangeably. We also use $s^{\\prime}$ and $s^{\\prime}(t)$ interchangeably. The solution sprayed into the tank adds $(0.02) 4=2 / 25$ grams of salt per minute. There are always 100 liters of liquid in the tank, containing $s$ grams of salt. So the density of salt in the tank is $s / 100$ grams per liter. The flow of water out of the tank therefore subtracts $4(s / 100)=s / 25$ grams of salt per minute. Then, for all $t \\in \\mathbb{R}$, we have $s^{\\prime}(t)=(2 / 25)-(s / 25)=(2-s) / 25$, and so $[s(t)=2] \\Rightarrow\\left[s^{\\prime}(t)=0\right]$. For all $t \\in \\mathbb{R}$,\n$$\n\frac{d}{d t}[\\ln (s-2)]=\frac{s^{\\prime}}{s-2}=\frac{-1}{25}=\frac{d}{d t}\\left[-\frac{t}{25}\right] .\n$$\nChoose $C \\in \\mathbb{R}$ such that, for all $t \\in \\mathbb{R}, \\ln ((s(t)-2))=-[t / 25]+C$. Let $K:=e^{C}$. Then, for all $t \\in \\mathbb{R}$, we have $(s(t))-2=K e^{-t / 25}$, and so $s(t)=2+K e^{-t / 25}$. Then $3=s(0)=2+K e^{0}=2+K$, so $K=1$. Then $s(100)=2+K e^{-100 / 25}=2+1 \\cdot e^{-4}=2+e^{-4}$. The answer is (D).\n\n", "college_medicine": "The following are multiple choice questions (with answers) about college medicine.\n\nQ: An expected side effect of creatine supplementation is:\n(A) muscle weakness. (B) gain in body mass. (C) muscle cramps. (D) loss of electrolytes.\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. Creatine supplementation is a dietary supplement that results in body mass gain. The answer is (B).\n\nQ: Which of the following is not a true statement?\n(A) Muscle glycogen is broken down enzymatically to glucose-1-phosphate (B) Elite endurance runners have a high proportion of Type I fibres in their leg muscles (C) Liver glycogen is important in the maintenance of the blood glucose concentration (D) Insulin promotes glucose uptake by all tissues in the body\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. Let\u2019s solve this step by step and go over each choice: \n(A) \u201cMuscle glycogen is broken down enzymatically to glucose-1-phosphate\u201d: This is a correct statement.\n(B) \u201cElite endurance runners have a high proportion of Type I fibres in their leg muscles\u201d: This is a correct statement.\n(C) \u201cLiver glycogen is important in the maintenance of the blood glucose concentration\u201d: This is a correct statement. \n(D) \u201cInsulin promotes glucose uptake by all tissues in the body\u201d: This is not a correct statement, because insulin promotes glucose uptake by the liver, adipose tissue, and muscle, but not all tissues. For instance, the tissues in the brain and red blood cells are not affected by insulin. The answer is (D).\n\nQ: A high school science teacher fills a 1 liter bottle with pure nitrogen and seals the lid. The pressure is 1.70 atm, and the room temperature is 25\u00b0C. Which two variables will both increase the pressure of the system, if all other variables are held constant?\n(A) Increasing temperature, increasing moles of gas (B) Increasing temperature, increasing volume (C) Decreasing volume, decreasing temperature (D) Decreasing moles of gas, increasing volume\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. The relevant equation for this is the ideal gas law: PV=nRT. To increase the pressure of the system (P), then either n (number of moles of the gas) or T (temperature) have to increase. The answer is (A).\n\nQ: In a genetic test of a newborn, a rare genetic disorder is found that has X-linked recessive transmission. Which of the following statements is likely true regarding the pedigree of this disorder?\n(A) All descendants on the maternal side will have the disorder. (B) Females will be approximately twice as affected as males in this family. (C) All daughters of an affected male will be affected. (D) There will be equal distribution of males and females affected.\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. Let\u2019s solve this step by step. Let's recall first that females have two X chromosomes, while males have one X and one Y chromosome. This is an important fact we need to know before answering this question. \nBecause a male can only pass his only one X chromosome to a daughter, if he is affected by this rare genetic disorder, then we know for sure that he will pass this rare genetic disorder to all his future-born daughters. Therefore, \u201c(C): All daughters of an affected male will be affected\u201d is a correct statement. The answer is (C).\n\nQ: Glucose is transported into the muscle cell:\n(A) via protein transporters called GLUT4. (B) only in the presence of insulin. (C) via hexokinase. (D) via monocarbylic acid transporters.\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. Glucose (also known as the blood sugar) is the main sugar found in the human body. It is transported into the muscle cell via diffusion through protein transporters called GLUT4. The answer is (A).\n\n", "college_physics": "The following are multiple choice questions (with answers) about college physics.\n\nQ: A refracting telescope consists of two converging lenses separated by 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope is\n(A) 4 (B) 5 (C) 6 (D) 20\nA: Let's think step by step. In a refracting telescope, if both lenses are converging, the focus of both lenses must be between the two lenses, and thus the focal lengths of the two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the focal length of the other must be 80 cm. The magnification is the ratio of these two focal lengths, or 4. The answer is (A).\n\nQ: The muon decays with a characteristic lifetime of about 10^-6 second into an electron, a muon neutrino, and an electron antineutrino. The muon is forbidden from decaying into an electron and just a single neutrino by the law of conservation of\n(A) charge (B) mass (C) energy and momentum (D) lepton number\nA: Let's think step by step. Lepton number must be conserved, meaning the total number of leptons minus the number of antileptons. If a muon decays into an electron and a single neutrino, the total lepton number would go from one to two, violating lepton number conservation. The answer is (D).\n\nQ: One end of a Nichrome wire of length 2L and cross-sectional area A is attached to an end of another Nichrome wire of length L and cross- sectional area 2A. If the free end of the longer wire is at an electric potential of 8.0 volts, and the free end of the shorter wire is at an electric potential of 1.0 volt, the potential at the junction of the two wires is most nearly equal to\n(A) 2.4 V (B) 3.3 V (C) 4.5 V (D) 5.7 V\nA: Let's think step by step. This is a simple voltage divider problem, where the longer wire has a resistance four times that of the shorter end. So the voltage divider ratio is 1 / 5, meaning that the potential in the middle is 1.0 V + (8.0 V - 1.0 V) * 1/5 = 2.4 V. The answer is (A).\n\nQ: A refracting telescope consists of two converging lenses separated by 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification of the telescope is\n(A) 4 (B) 5 (C) 6 (D) 20\nA: Let's think step by step. In a refracting telescope, if both lenses are converging, the focus of both lenses must be between the two lenses, and thus the focal lengths of the two lenses must add up to their separation. Since the focal length of one lens is 20 cm, the focal length of the other must be 80 cm. The magnification is the ratio of these two focal lengths, or 4. The answer is (A).\n\nQ: For which of the following thermodynamic processes is the increase in the internal energy of an ideal gas equal to the heat added to the gas?\n(A) Constant temperature (B) Constant volume (C) Constant pressure (D) Adiabatic\nA: Let's think step by step. Heat added to the gas can go into the gases internal energy or work done against an external force. However, if the volume of the gas container is constant, no work will be done (since work is pressure times change in volume). So, at constant volume, all of the heat goes into the internal energy. The answer is (B).\n\n", "computer_security": "The following are multiple choice questions (with answers) about computer security.\n\nQ: SHA-1 has a message digest of\n(A) 160 bits (B) 512 bits (C) 628 bits (D) 820 bits\nA: Let's think step by step. Since SHA-1 is a hash function which takes an input and produces a 160-bit (20-byte) hash value, its message digest is 160 bits. The answer is (A).\n\nQ: _____________ can modify data on your system \u2013 so that your system doesn\u2019t run correctly or you can no longer access specific data, or it may even ask for ransom in order to give your access.\n(A) IM \u2013 Trojans (B) Backdoor Trojans (C) Trojan-Downloader (D) Ransom Trojan\nA: Let's think step by step. The system is asking for trojans, which are for ransom, which means ransom trojan. The answer is (D).\n\nQ: What is ethical hacking?\n(A) \"Hacking\" ethics so they justify unintended selfish behavior (B) Hacking systems (e.g., during penetration testing) to expose vulnerabilities so they can be fixed, rather than exploited (C) Hacking into systems run by those whose ethics you disagree with (D) A slang term for rapid software development, e.g., as part of hackathons\nA: Let's think step by step. Ethical hacking is a process of detecting vulnerabilities in an application, system, or organization's infrastructure that an attacker can use to exploit an individual or organization. They use this process to prevent cyberattacks and security breaches by lawfully hacking into the systems and looking for weak points. The answer is (B).\n\nQ: The ____________ is anything which your search engine cannot search.\n(A) Haunted web (B) World Wide Web (C) Surface web (D) Deep Web\nA: Let's think step by step. The search engine searches on the Surface Web, which is the portion of the world wide web which is visible so (B,C) are wrong. The Haunted Web doesn\u2019t correspond to an internet concept. The Deep Web is the part of the World Wide Web which is not indexed. The answer is (D).\n\nQ: Exploitation of the Heartbleed bug permits\n(A) overwriting cryptographic keys in memory (B) a kind of code injection (C) a read outside bounds of a buffer (D) a format string attack\nA: Let's think step by step. The Heartbleed Bug is a serious vulnerability in the popular OpenSSL cryptographic software library. Heartbleed resulted from improper input validation (due to a missing bounds check) in the implementation of the TLS heartbeat extension. The vulnerability was classified as a buffer over-read, a situation where more data can be read than should be allowed. The answer is (C).\n\n", "conceptual_physics": "\nThe following are multiple choice questions (with answers) about conceptual physics.\n\nQ: Colors in a soap bubble result from light\n(A) converted to a different frequency (B) deflection (C) interference (D) polarization\nA: Let's think step by step. In a soap bubble film, the light bounces between the two soap-air interfaces many times, interfering with itself constructively or destructively depending on the width of the film. This results in different colors being visible. The answer is (C).\n\nQ: Compared with the mass of a uranium atom undergoing fission, the combined masses of the products after fission are\n(A) less (B) more (C) the same (D) zero\nA: Let's think step by step. Fission releases energy, which comes from the rest mass of its initial nucleus. Thus the mass of the products is less than the mass of the reactant uranium nucleus. The answer is (A).\n\nQ: Things that are equivalent according to the equivalence principle are\n(A) space and time. (B) a traveling twin and a stay-at-home twin. (C) gravity and acceleration. (D) mass and energy.\nA: Let's think step by step. Einstein\u2019s famous equivalence principle states that gravity and acceleration are equivalent. The answer is (C).\n\nQ: Which of these three elements has the most mass per nucleon?\n(A) Hydrogen (B) Iron (C) Uranium (D) Same in each\nA: Let's think step by step. Due to nuclear binding energy, the mass of an atomic nucleus is less than the sum of individual masses of the free constituent protons and neutrons; this is known as the mass defect. Hydrogen has no mass defect because it has only a single nucleon, so it will have the most mass per nucleon. The answer is (A).\n\nQ: A model airplane flies slower when flying into the wind and faster with wind at its back. When launched at right angles to the wind a cross wind its groundspeed compared with flying in still air is\n(A) the same (B) greater (C) less (D) either greater or less depending on wind speed\nA: Let's think step by step. The plane\u2019s speed in the direction of the wind is greater than it would be in the absence of wind, and its direction orthogonal to the wind is the same as it would be in the absence of the wind. The total speed, which is these two components added in quadrature, is thus greater than the speed in still air. The answer is (B).\n\n", "econometrics": "The following are multiple choice questions (with answers) about econometrics.\n\nQ: Suppose now that a researcher wishes to use information criteria to determine the optimal lag length for a VAR. 500 observations are available for the bi-variate VAR, and the values of the determinant of the variance-covariance matrix of residuals are 0.0336, 0.0169, 0.0084, and 0.0062 for 1, 2, 3, and 4 lags respectively. What is the optimal model order according to Akaike's information criterion?\n(A) 1 lag (B) 2 lags (C) 3 lags (D) 4 lags\nA: Let's think step by step. We refer to Wikipedia articles on econometrics for help. Let\u2019s solve this problem step by step. First of all, let\u2019s recall that for a given set of data, Akaike's information criterion (AIC) allows us to measure how well a statistical model fits the data; it is an estimator of prediction error. Here in this problem we will need to use the formula ln(det(sigma_hat)) + (2 * k / T) to determine the values of Akaike\u2019s criterion, where ln denotes the natural log function, det the determinant function, k the total number of parameters in total (across both equations), and T the number of observations (which, in this case, is equal to 500). For 1 lag, the number of parameters in total is equal to 6; for 2 lags, it is 10; for 3 lags, it is 14; and for 4 lags, it is 18. Now, let\u2019s calculate the values of the criterion for each lag:\n(A) 1 lag: ln(0.0336) + (2 * 6 / 500) = ln(0.0336) + (12 / 500) = -3.369\n(B) 2 lags: ln(0.0169) + (2 * 10 / 500) = ln(0.0169) + (20 / 500) = -4.040\n(C) 3 lags: ln(0.0084) + (2 * 14 / 500) = ln(0.0084) + (28 / 500) =-4.724\n(D) 4 lags: ln(0.0062) + (2 * 18 / 500) = ln(0.0062) + (36 / 500) =-5.011\nBecause the optimal model order according to AIC minimizes the information criterion, the answer should be the one with the lowest value. In this case, (D) has the lowest value. The answer is (C).\n\nQ: Consider the following AR(1) model with the disturbances having zero mean and unit variance\nyt = 0.2 + 0.4 yt-1 + ut\nThe (unconditional) mean of y will be given by\n(A) 0.2 (B) 0.4 (C) 0.5 (D) 0.33\nA: Let's think step by step. We refer to Wikipedia articles on econometrics for help. Let\u2019s solve this problem step by step. If we have a an AR(1) model with the disturbances having zero mean and unit variance, then the unconditional mean of y is equal to the following:\nunconditional mean of y = (the intercept term) / (1 - autoregressive coefficient)\nWe know that the intercept term is 0.2 and the autoregressive coefficient is 0.4; thus, we have:\nunconditional mean of y = (0.2) / (1 - 0.4) = (0.2) / (0.6) = 2 / 6 = 1 / 3, which is approximately 0.33. That means that the answer should be (D) 0.33. The answer is (D).\n\nQ: What would be then consequences for the OLS estimator if heteroscedasticity is present in a regression model but ignored?\n(A) It will be biased (B) It will be inconsistent (C) It will be inefficient (D) All of (a), (b) and (c) will be true.\nA: Let's think step by step. We refer to Wikipedia articles on econometrics for help. Heteroscedasticity refers to the condition where the variance of the error terms is not constant across multiple observations. If heteroscedasticity is present in a regression model, then the coefficient estimates in the OLS estimator will be not only unbiased and consistent but also inefficient. Because (A) and (B) are incorrect choices and (C) is a correct choice, (D) cannot be the right answer. Ultimately, (C) is the only true choice. The answer is (C).\n\nQ: Suppose that a test statistic has associated with it a p-value of 0.08. Which one of the following statements is true?\n(i) If the size of the test were exactly 8%, we would be indifferent between rejecting and not rejecting the null hypothesis\n(ii) The null would be rejected if a 10% size of test were used\n(iii) The null would not be rejected if a 1% size of test were used\n(iv) The null would be rejected if a 5% size of test were used.\n(A) (ii) and (iv) only (B) (i) and (iii) only (C) (i), (ii), and (iii) only (D) (i), (ii), (iii), and (iv).\nA: Let's think step by step. We refer to Wikipedia articles on econometrics for help. Let\u2019s reason about each of the options.\n(i) is a true statement.\n(ii) is a true statement.\n(iii) is a true statement.\n(iv) is not a true statement. Thus, (i), (ii), and (iii) are true. The answer is (C).\n\nQ: For a stationary autoregressive process, shocks will\n(A) Eventually die away (B) Persist indefinitely (C) Grow exponentially (D) Never occur\nA: Let's think step by step. We refer to Wikipedia articles on econometrics for help. This is a formal logic problem about stationally process. For a stationary autoregressive process, shocks will eventually die away. The answer is (A).\n\n", "electrical_engineering": "\nThe following are multiple choice questions (with answers) about electrical engineering.\n\nQ: A point pole has a strength of 4\u03c0 * 10^-4 weber. The force in newtons on a point pole of 4\u03c0 * 1.5 * 10^-4 weber placed at a distance of 10 cm from it will be\n(A) 15 N. (B) 20 N. (C) 7.5 N. (D) 3.75 N.\nA: Let's think step by step. The force between two point poles is given by m_1m_2/(mu_0 4 \\pi r^2), in analogy to Coulomb\u2019s law. Plugging in the values given in the question, we calculate that the force is approximately 15 N. The answer is (A).\n\nQ: The coil of a moving coil meter has 100 turns, is 40 mm long and 30 mm wide. The control torque is 240*10-6 N-m on full scale. If magnetic flux density is 1Wb/m2 range of meter is\n(A) 1 mA. (B) 2 mA. (C) 3 mA. (D) 4 mA.\nA: Let's think step by step. The torque on a coil in a uniform magnetic field is given by BANI, where B is the magnetic flux density, A is the area of the coil, N is the number of turns, and I is the current. So we have that I = (Torque)/(BAN), or 240e-6/(1200e-6 * 100 * 1) = 2e-3. The answer is (B).\n\nQ: In an SR latch built from NOR gates, which condition is not allowed\n(A) S=0, R=0 (B) S=0, R=1 (C) S=1, R=0 (D) S=1, R=1\nA: Let's think step by step. An SR latch is a set-reset latch; in the case where S=1 and R=1, the circuit has no stable state; instead a race condition will be produced within the circuit, so the device will be in an undefined state. So S=1, R=1 is an illegal input. The answer is (D).\n\nQ: Two long parallel conductors carry 100 A. If the conductors are separated by 20 mm, the force per meter of length of each conductor will be\n(A) 100 N. (B) 0.1 N. (C) 1 N. (D) 0.01 N.\nA: Let's think step by step. The magnetic force-per-length between two current-carrying conductors is given by \\mu_0 I_1 I_2 / (2 \\pi r), where $r$ is the separation distance and I_1 and I_2 are the currents. Plugging in 100 A for I_1 and I_2, and 20 mm for r, gives 0.1 N. The answer is (B).\n\nQ: In a 2 pole lap winding dc machine , the resistance of one conductor is 2\u03a9 and total number of conductors is 100. Find the total resistance\n(A) 200\u03a9 (B) 100\u03a9 (C) 50\u03a9 (D) 10\u03a9\nA: Let's think step by step. In lap winding, effectively two resistors are connected in parallel, so the actual resistance of each pair is 1 Ohm. Since we have 50 pairs, we get a total resistance of 50 Ohms. The answer is (C).\n\n", "elementary_mathematics": "The following are multiple choice questions (with answers) about elementary mathematics.\n\nQ: Olivia used the rule \"Add 11\" to create the number pattern shown below. 10, 21, 32, 43, 54. Which statement about the number pattern is true?\n(A) The 10th number in the pattern will be an even number.\n(B) The number pattern will never have two even numbers next to each other.\n(C) The next two numbers in the pattern will be an even number then an odd number.\n(D) If the number pattern started with an odd number then the pattern would have only odd numbers in it.\nA: Let's think step by step. Choice A is incorrect because every even-numbered term in the pattern is odd, and 10 is an even number. Choice B is correct, because adding an odd number (in this case 11) to an odd number produces an even number, and adding an odd number to an even number produces an odd number. Thus the terms in the pattern will alternate between odd and even, so there will never be two even numbers next to each other. Choice C is incorrect because the last term in the example is even (54), and we know that the terms will alternate between even and odd. Choice D is incorrect because the terms in the pattern will alternate between odd and even, regardless of the value of the first term. The answer is (B).\n\nQ: The population of the city where Michelle was born is 145,826. What is the value of the 5 in the number 145,826?\n(A) 5 thousands\n(B) 5 hundreds\n(C) 5 tens\n(D) 5 ones\nA: Let's think step by step. Choice A is correct, because there are three digits following the 5, so\nthe 5 is in the thousands place. Thus the other choices are incorrect. The answer is (A).\n\nQ: A store sells 107 different colors of paint. They have 25 cans of each color in storage. The number of cans of paint the store has in storage can be found using the expression below. 107 \u00d7 25. How many cans of paint does the store have in storage?\n(A) 749\n(B) 2,675\n(C) 2,945\n(D) 4,250\nA: Let's think step by step. We can calculate 107 x 25 = (100 x 25) + (7 x 25) = 2500 + 175 = 2675. The answer is (B).\n\nQ: A total of 30 players will play basketball at a park. There will be exactly 5 players on each team. Which statement correctly explains how to find the number of teams needed?\n(A) Add 5 to 30 to find 35 teams.\n(B) Divide 30 by 5 to find 6 teams.\n(C) Multiply 30 and 5 to find 150 teams.\n(D) Subtract 5 from 30 to find 25 teams.\nA: Let's think step by step. We want to find the number of teams. We know that there are 5 players/team, and 30 players. Thus to get the number of teams we divide players by players/team, so 30 players / 5 players/team = 6 teams. The answer is (B).\n\nQ: Which expression is equivalent to 5 x 9?\n(A) (5 x 4) x (6 x 5)\n(B) (5 x 5) + (5 x 4)\n(C) (5 x 5) + (5 x 9)\n(D) (5 x 9) x (6 x 9)\nA: Let's think step by step. We know that 9 = (5 + 4), so 5 x 9 = 5 x (5 + 4) = (5 x 5) + (5 x 4). The answer is (B).\n\n", "formal_logic": "The following are multiple choice questions (with answers) about formal logic.\n\nQ: Which of the given formulas of PL is the best symbolization of the following sentence?\nTurtles live long lives and are happy creatures, unless they are injured.\n(A) (L \u2022 H) \u2261 I (B) (L \u2022 H) \u2228 I (C) L \u2022 (H \u2228 I) (D) L \u2022 (H \u2283 R).\nA: Let's think step by step. We refer to Wikipedia articles on formal logic for help. Let\u2019s solve this step by step. Let \u201cL\u201d denote \u201cliving long\u201d, H \u201cbeing happy\u201d, and \u201cI\u201d \u201cbeing injured\u201d. Now, consider each choice:\n(A) means (living long AND being happy) is equivalent to (being injured). \n(B) means (living long AND being happy) OR (being injured). \n(C) means (living long) AND (being happy OR being injured). \n(D) means (living long) AND (being happy implies being R), but what R denotes is not clear.\nObviously, (B) is the best symbolization of the original sentence. The answer is (B).\n\nQ: Select the best translation into predicate logic.George borrows Hector's lawnmower. (g: George; h: Hector; l: Hector's lawnmower; Bxyx: x borrows y from z).\n(A) Blgh (B) Bhlg (C) Bglh (D) Bghl\nA: Let's think step by step. We refer to Wikipedia articles on formal logic for help. Let\u2019s solve this step by step. We are told that \u201cBxyx\u201d means \u201cx borrows y from z\u201d. We can rewrite \u201cGeorge borrows Hector's lawnmower\u201d as \u201cGeorge borrows a lawnmower from Hector\u201d, which can then be translated into predicate logic as \u201cBglh\u201d. The answer \u201cBglh\u201d appears in (C); therefore, (C) must be the correct answer. The answer is (C).\n\nQ: \nSelect the best English interpretation of the given arguments in predicate logic.\nDm\n(\u2200x)(Wx \u2283 ~Dx). \n(\u2200x)Wx \u2228 Ag\t/ (\u2203x)Ax\n(A) Marina is a dancer. Some weaklings are not dancers. Either everything is a weakling or Georgia plays volleyball. So something plays volleyball. (B) Marina is a dancer. No weakling is a dancer. Everything is either a weakling or plays volleyball. So something plays volleyball. (C) Marina is a dancer. Some weaklings are not dancers. Everything is either a weakling or plays volleyball. So something plays volleyball. (D) Marina is a dancer. No weakling is a dancer. Either everything is a weakling or Georgia plays volleyball. So something plays volleyball.\nA: Let's think step by step. We refer to Wikipedia articles on formal logic for help. Let\u2019s solve this step by step. Let \u201cD\u201d denote \u201cbeing a dancer\u201d, \u201cm\u201d denote \u201cMaria\u201d, \u201cg\u201d denote \u201cGeorgia\u201d, \u201cW\u201d denote \u201cweakling\u201d, \u201cA\u201d denote \u201cplaying volleyball\u201d. Then, we have the following:\n1. Dm \u2192 Maria is a dance.\n2. (\u2200x)(Wx \u2283 ~Dx). \u2192 For all x, if x is a weakling, then x is not a dancer. In other words, no weakling is a dancer.\n3. (\u2200x)Wx \u2228 Ag\t/ (\u2203x)Ax \u2192 For all x, x is a weakling or Georgia plays volleyball. So there exists an x that plays volleyball. \nOptions (A) and (C) do claim that some weaklings are not dancers, but the second argument strongly states that no weakling is a dancer. Thus, we can eliminate them. Option (B) omits the important detail about Georgia playing volleyball. Option (D) has all the details presented in the arguments and is the best English interpretation of the arguments. The answer is (D).\n\nQ: Select the best translation into predicate logic: No people drive on Mars.\n(A) ~Pd (B) (\u2200x)(Px \u2228 ~Dx) (C) (\u2200x)(Px \u2283 ~Dx) (D) ~Dp\nA: Let's think step by step. We refer to Wikipedia articles on formal logic for help. Let\u2019s solve this step by step. Let \u201cP\u201d denote \u201cbeing on Mars\u201d and \u201cD\u201d denote \u201cdriving on Mars\u201d. Then let\u2019s consider each option:\nOption (A): ~Pd \u2192 d is not on Mars.\nOption (B): (\u2200x)(Px \u2228 ~Dx) \u2192 For all x, x is on Mars and x do not drive on Mars.\nOption (C): (\u2200x)(Px \u2283 ~Dx) \u2192 For all x, x is on Mars implies that x do not drive on Mars.\nOption (D): ~Dp: \u2192 p do not drive on Mars.\nOf all these options, Option (C) appears to be the best and most meaningful interpretation of the argument \u201cNo people drive on Mars.\u201d The answer is (C).\n\n", "global_facts": "The following are multiple choice questions (with answers) about global facts.\n\nQ: As of 2017, how many of the world\u2019s 1-year-old children today have been vaccinated against some disease? *\n(A) 80% (B) 60% (C) 40% (D) 20%\nA: Let's think step by step. We refer to Wikipedia articles on global facts for help. According to data published by the World Health Organization, the nummber of 1-year-old children vaccinated in 2017 exceeds 80%. The answer is (A).\n\nQ: As of 2019, about what percentage of Americans agree that the state is run for the benefit of all the people?\n(A) 31% (B) 46% (C) 61% (D) 76%\nA: Let's think step by step. We refer to Wikipedia articles on global facts for help. In 2019, about 46% percentage of Americans agree that the state is run for the benefit of all the people. The answer is (B).\n\nQ: As of 2019, about what percentage of Russians say it is very important to have free media in our country without government/state censorship?\n(A) 38% (B) 53% (C) 68% (D) 83%\nA: Let's think step by step. We refer to Wikipedia articles on global facts for help. As of 2019, about 38% of Russians say it is very important to have free media in our country. The answer is (A).\n\nQ: As of 2015, since 1990 forests have ____ in Europe and have ____ in Africa and the Americas.\n(A) increased, increased (B) increased, decreased (C) decreased, increased (D) decreased, decreased\nA: Let's think step by step. We refer to Wikipedia articles on global facts for help. As of 2015, since 1990 forests have increased in Europe and have decreased in Africa and the Americas. The answer is (B).\n\nQ: Which of the following pairs of statements are both true (as of 2019)?\n(A) People tend to be optimistic about their own future and the future of their nation or the world. (B) People tend to be optimistic about their own future but pessimistic about the future of their nation or the world. (C) People tend to be pessimistic about their own future but optimistic about the future of their nation or the world. (D) People tend to be pessimistic about their own future and the future of their nation or the world.\nA: Let's think step by step. We refer to Wikipedia articles on global facts for help. As of 2019, most people tend to be optimistic about their own future but pessimistic about the future of their nation or the world. The answer is (B).\n\n", "high_school_biology": "The following are multiple choice questions (with answers) about high school biology.\n\nQ: In animal cells, which of the following represents the most likely pathway that a secretory protein takes as it is synthesized in a cell?\n(A) Plasma membrane\u2013Golgi apparatus\u2013ribosome\u2013secretory vesicle\u2013rough ER (B) Ribosome\u2013Golgi apparatus\u2013rough ER\u2013secretory vesicle\u2013plasma membrane (C) Plasma membrane\u2013Golgi apparatus\u2013ribosome\u2013secretory vesicle\u2013rough ER (D) Ribosome\u2013rough ER\u2013Golgi apparatus\u2013secretory vesicle\u2013plasma membrane\nA: Let's think step by step. Protein synthesis starts at the ribosome, so we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum and moves from there to the Golgi apparatus, where it is modified and packaged into a vesicle. The vesicle then floats to the plasma membrane and is secreted. The answer is (D).\n\nQ: A mutation in a bacterial enzyme changed a previously polar amino acid into a nonpolar amino acid. This amino acid was located at a site distant from the enzyme\u2019s active site. How might this mutation alter the enzyme\u2019s substrate specificity?\n(A) By changing the enzyme\u2019s pH optimum (B) By changing the enzyme\u2019s location in the cell (C) By changing the shape of the protein (D) An amino acid change away from the active site cannot alter the enzyme\u2019s substrate specificity.\nA: Let's think step by step. A change in an amino acid leads to a change in the primary structure of the protein. A change in the primary structure may lead to a change in the secondary and the tertiary structure of the protein. A change in the tertiary structure means a change in the shape of the protein, so (C) has to be correct. Since the change does not affect the active site of the enzyme, we do not expect the activity of the enzyme to be affected. The answer is (C).\n\nQ: Which of the following is not a way to form recombinant DNA?\n(A) Translation (B) Conjugation (C) Specialized transduction (D) Transformation\nA: Let's think step by step. The introduction of foreign DNA or RNA into bacteria or eukaryotic cells is a common technique in molecular biology and scientific research. There are multiple ways foreign DNA can be introduced into cells including transformation, transduction, conjugation, and transfection. In contrast, (A) is not a way to form DNA: during translation the ribosomes synthesize proteins from RNA. The answer is (A).\n\nQ: Homologous structures are often cited as evidence for the process of natural selection. All of the following are examples of homologous structures EXCEPT\n(A) the wings of a bird and the wings of a bat (B) the flippers of a whale and the arms of a man (C) the pectoral fins of a porpoise and the flippers of a seal (D) the forelegs of an insect and the forelimbs of a dog\nA: Let's think step by step. \u200b\u200bHomologous structures are similar physical features in organisms that share a common ancestor \u200b\u200bbut different functions. Comparisons (B) and (C) are clearly homologous because they share a common ancestor and the structures serve different purposes. Bat wings and birg wings are also homologous, while they are both wings, the forelimbs serve different purposes. Insects and dogs are very far ancestors since one is vertebrate while the other is invertebrate and the forelimbs serve the same purpose, so they are not homologous. The answer is (D).\n\nQ: Which of the following is not known to be involved in the control of cell division?\n(A) Cyclins (B) Protein kinases (C) Checkpoints (D) Fibroblast cells\nA: Let's think step by step. Normal cells move through the cell cycle in a regulated way. At the checkpoint stage, they use information about their own internal state and cues from the environment around them to decide whether to proceed with cell division. Cues like these act by changing the activity of core cell cycle regulators inside the cell. The most common regulators are cyclins and cyclin-dependent kinases. Fibroblast cells do not play any role in cell division. The answer is (D).\n\n", "high_school_chemistry": "The following are multiple choice questions (with answers) about high school chemistry.\n\nQ: Which of the following is considered an acid anhydride?\n(A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3\nA: Let's think step by step. An acid anhydride is a compound that is derived by removing water from an acid. The chemical formula for water is H2O, which means that we need to determine which of these options, when combined with H2O, forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, or sulfuric acid. The answer is (C).\n\nQ: Which of the following is expected to be a polar molecule?\n(A) PCl4F (B) BF3 (C) CO2 (D) Si(CH3)4\nA: Let's think step by step. A polar molecule is one that has a slightly positive charge on one end of the molecule and a slightly negative charge on the other end. Boron trifluoride (BF3) has Boron as the center atom and three fluorine atoms attached to it; it is trigonal planar and symmetric, so it is nonpolar. Carbon Dioxide (CO2) has Carbon as the central atom with double bonds to two Oxygen atoms - this is also symmetrical and therefore nonpolar. The same is the case for tetramethyl silane (SI(CH3)4), which is a Silicon atom surrounded by four methyl groups. The structure of PCL4F is that Phosphorus is the central atom, attached to four chlorines and one fluorine atom. This is asymmetrical, and therefore has a net dipole and is expected to be a polar molecule. The answer is (A).\n\nQ: From the solubility rules, which of the following is true?\n(A) All chlorides, bromides, and iodides are soluble (B) All sulfates are soluble (C) All hydroxides are soluble (D) All ammonium-containing compounds are soluble\nA: Let's think step by step. The chlorides, bromides, and iodides of lead, silver, and mercury are not soluble in water. This rules out (A). The sulfates of lead, barium, and calcium are not soluble in water, which rules out (B). The hydroxides of any metal besides sodium, potassium, ammonium, calcium, and barium are insoluble. This rules out (C). Typically ammonium ions indicate a soluble ionic substance. The answer is (D).\n\nQ: A new compound is synthesized and found to be a monoprotic acid with a molar mass of 248 g/mol. When 0.0050 mol of this acid are dissolved in 0.500 L of water, the pH is measured as 3.89. What is the pKa of this acid?\n(A) 3.89 (B) 7.78 (C) 5.78 (D) 2.33\nA: Let's think step by step. Recall that $[A] = [H^{+}]$. Here, this is equal to $$10^{-3.89}$. Then we have $K_{a} = $\nrac{[H^{+}][A^{-}]}{[HA]} = \nrac{10^{-3.89} \\cdot 10^{-3.89}}{10^{-2}}. The resulting exponent is $-3.89 + (-3.89) - (-2) = 5.78$, therefore $K_a = 10^{-5.78}$. The $pK_a$ is the negative log of $K_a$, which is equal to $5.78$. The answer is (C).\n\nQ: A solution contains 2.00 mole of acetic acid, CH3COOH, and 1.00 mole of calcium acetate, Ca(CH3COO)2. The solution is able to resist the addition of a small amount of strong acid or strong base with only minor changes in the pH of the solution. Larger quantities of strong acid or strong base can cause a significant change in pH. How many moles of nitric acid, HNO3, may be added before the pH begins to change significantly?\n(A) 0.500 mole (B) 1.00 mole (C) 2.00 mole (D) 3.00 mole\nA: Let's think step by step. We would like to compute the buffer capacity of this solution. First we write the equation for the ionization of the weak acid, in this case of acetic acid. $CH_{3}COOH (aq) + H_{2}O \nightarrow H_{3}O^{+} + CH3COO^{-}$. The conjugate base is therefore the acetate ion. The added strong acid, Nitric acid, will react with the conjugate base. Therefore the maximum amount of acid that can be added will be equal to the amount of acetate ion, or 2 moles. The answer is (C).\n\n", "high_school_computer_science": "The following are multiple choice questions (with answers) about high school computer science.\n\nQ: Which of the following is an example of the use of a device on the Internet of Things (IoT) ?\n(A) A car alerts a driver that it is about to hit an object. (B) A hiker uses a G P S watch to keep track of her position. (C) A refrigerator orders milk from an online delivery service when the milk in the refrigerator is almost gone. (D) A runner uses a watch with optical sensors to monitor his heart rate.\nA: Let's think step by step. The term Internet of Things (IoT) refers to common devices which are connected to the internet, enabling new functionality. Choice A is incorrect because it does not describe an internet connected device. In choice B, the watch is only described as having GPS functionality but no internet connectivity. Choice C describes a common device (a refrigerator) which has internet connectivity enabling new functionality (online ordering). Choice D does not mention internet connectivity for the watch, only optical sensors. The answer is (C).\n\nQ: Many Web browsers allow users to open anonymous windows. During a browsing session in an anonymous window, the browser does not record a browsing history or a list of downloaded files. When the anonymous window is exited, cookies created during the session are deleted. Which of the following statements about browsing sessions in an anonymous window is true?\n(A) The activities of a user browsing in an anonymous window will not be visible to people who monitor the user's network, such as the system administrator. (B) Items placed in a Web store's shopping cart for future purchase during the anonymous browsing session will not be saved on the user's computer. (C) A user will not be able to log in to e-mail or social media accounts during the anonymous browsing session. (D) A user browsing in an anonymous window will be protected from viruses launched from any web sites visited or files downloaded.\nA: Let's think step by step. Choice A is incorrect as it only describes network traffic, which an anonymous browser does not change. Choice B is correct as it correctly describes how an anonymous browser will prevent saving data on the user\u2019s computer after the session is ended. Choice C is incorrect because an anonymous browser will not prevent logging in to email or social media accounts. Choice D is incorrect because an anonymous browser in itself performs no virus protection. The answer is (B).\n\nQ: In the program below, the initial value of X is 5 and the initial value of Y is 10.\nIF (X < 0){\n DISPLAY (\"Foxtrot\")\n} ELSE {\n IF (X > Y){\n DISPLAY (\"Hotel\")\n } ELSE {\n IF (Y > 0){\n DISPLAY (\"November\")\n } ELSE {\n DISPLAY (\"Yankee\")\n }\n }\n}\nWhat is displayed as a result of running the program?\n(A) Foxtrot (B) Hotel (C) November (D) Yankee\nA: Let's think step by step. Because X has the value 5, the first conditional IF (X < 0) is false, so we move to the first ELSE clause. Because X is 5 and Y is 10, the second conditional IF (X > Y) is false, so we move to the following ELSE clause. Since Y is 10, the conditional IF (Y > 0) is true, so the command DISPLAY (\"November\") is executed. The answer is (C).\n\nQ: What is the output of \"abc\"[::-1] in Python 3?\n(A) Error (B) abc (C) cba (D) c\nA: Let's think step by step. We know that the slicing operator [::-1] takes all of the elements in the string in reverse order, so we reverse the order of the string \"abc\", resulting in \"cba\". The answer is (C).\n\nQ: A list of numbers has n elements, indexed from 1 to n. The following algorithm is intended to display the number of elements in the list that have a value greater than 100. The algorithm uses the variables count and position. Steps 3 and 4 are missing.\n Step 1: Set count to 0 and position to 1.\n Step 2: If the value of the element at index position is greater than 100, increase the value of count by 1.\n Step 3: (missing step)\n Step 4: (missing step)\n Step 5: Display the value of count.\nWhich of the following could be used to replace steps 3 and 4 so that the algorithm works as intended?\n(A) Step 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until the value of count is greater than 100.\n(B) Step 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until the value of position is greater than n.\n(C) Step 3: Repeat step 2 until the value of count is greater than 100.\n Step 4: Increase the value of position by 1.\n(D) Step 3: Repeat step 2 until the value of position is greater than n.\n Step 4: Increase the value of count by 1.\nA: Let's think step by step. Choice A is incorrect, because its Step 4 has an incorrect termination condition, stopping when count is greater than 100. We need to stop after inspecting all elements in the list. Choice B is correct because it correctly increments both count and position, and correctly repeats these steps and terminates when all elements in the list have been inspected. Choice C is incorrect because it incorrectly increments the variable count until its value is greater than 100, regardless of the elements in the list. Choice D is incorrect because its step 3 does not increment the value of position, so it will repeat forever. The answer is (B).\n\n", "high_school_european_history": "The following are multiple choice questions (with answers) about high school european history.\n\nQ: This question refers to the following information.\nAlbeit the king's Majesty justly and rightfully is and ought to be the supreme head of the Church of England, and so is recognized by the clergy of this realm in their convocations, yet nevertheless, for corroboration and confirmation thereof, and for increase of virtue in Christ's religion within this realm of England, and to repress and extirpate all errors, heresies, and other enormities and abuses heretofore used in the same, be it enacted, by authority of this present Parliament, that the king, our sovereign lord, his heirs and successors, kings of this realm, shall be taken, accepted, and reputed the only supreme head in earth of the Church of England, called Anglicans Ecclesia; and shall have and enjoy, annexed and united to the imperial crown of this realm, as well the title and style thereof, as all honors, dignities, preeminences, jurisdictions, privileges, authorities, immunities, profits, and commodities to the said dignity of the supreme head of the same Church belonging and appertaining; and that our said sovereign lord, his heirs and successors, kings of this realm, shall have full power and authority from time to time to visit, repress, redress, record, order, correct, restrain, and amend all such errors, heresies, abuses, offenses, contempts, and enormities, whatsoever they be, which by any manner of spiritual authority or jurisdiction ought or may lawfully be reformed, repressed, ordered, redressed, corrected, restrained, or amended, most to the pleasure of Almighty God, the increase of virtue in Christ's religion, and for the conservation of the peace, unity, and tranquility of this realm; any usage, foreign land, foreign authority, prescription, or any other thing or things to the contrary hereof notwithstanding.\nEnglish Parliament, Act of Supremacy, 1534\nFrom the passage, one may infer that the English Parliament wished to argue that the Act of Supremacy would\n(A) give the English king a new position of authority (B) give the position of head of the Church of England to Henry VIII alone and exclude his heirs (C) establish Calvinism as the one true theology in England (D) end various forms of corruption plaguing the Church in England\nA: Let's think step by step. We refer to Wikipedia articles on european history for help. The Act of Supremacy states that it grants authority to the king \"to repress and extirpate all errors, heresies, and other enormities and abuses\", referring to the corruption in the Church of England. The answer is (D).\n\nQ: This question refers to the following information.\nRead the following excerpt.\nThe revolutionary seed had penetrated into every country and spread more or less. It was greatly developed under the r\u00e9gime of the military despotism of Bonaparte. His conquests displaced a number of laws, institutions, and customs; broke through bonds sacred among all nations, strong enough to resist time itself; which is more than can be said of certain benefits conferred by these innovators.\nThe monarchs will fulfil the duties imposed upon them by Him who, by entrusting them with power, has charged them to watch over the maintenance of justice, and the rights of all, to avoid the paths of error, and tread firmly in the way of truth. Placed beyond the passions which agitate society, it is in days of trial chiefly that they are called upon to despoil realities of their false appearances, and to show themselves as they are, fathers invested with the authority belonging by right to the heads of families, to prove that, in days of mourning, they know how to be just, wise, and therefore strong, and that they will not abandon the people whom they ought to govern to be the sport of factions, to error and its consequences, which must involve the loss of society.\nUnion between the monarchs is the basis of the policy which must now be followed to save society from total ruin. . . .\nLet them not confound concessions made to parties with the good they ought to do for their people, in modifying, according to their recognized needs, such branches of the administration as require it.\nLet them be just, but strong; beneficent, but strict.\nLet them maintain religious principles in all their purity, and not allow the faith to be attacked and morality interpreted according to the social contract or the visions of foolish sectarians.\nLet them suppress Secret Societies; that gangrene of society.\n\u2014Klemens von Metternich, Political Confession of Faith, 1820\nWhich of the following was the greatest cause of the fears expressed by Metternich in the document above?\n(A) The ideas of personal liberty and nationalism conceived during the Enlightenment resulted in radical revolutions that could spread throughout Europe. (B) The conquest of Europe by Napoleon led to the creation of new factions and shifted the European balance of power. (C) The power of monarchs had grown to the point where it needed to be checked by other powers within each nation or domination of civilians would occur. (D) The rising and falling economic cycle of the newly emerging capitalist economy could lead to civilian unrest that must be suppressed.\nA: Let's think step by step. We refer to Wikipedia articles on european history for help. The fears of revolution in early 19th century Europe expressed by Klemens von Metternich, a conservative Austrian statesman, were a direct result of the age of Enlightenment, a period of European history where the absolute power of the monarchy was challenged with ideas of individual liberty and nationalism, leading to the French revolution and its effects all over Europe. The answer is (A).\n\nQ: This question refers to the following information.\nThe excerpts below are from the Navigation Acts of 1651.\n[A]fter the first day of December, one thousand six hundred fifty and one, and from thence forwards, no goods or commodities whatsoever of the growth, production or manufacture of Asia, Africa or America, or of any part thereof; or of any islands belonging to them, or which are described or laid down in the usual maps or cards of those places, as well of the English plantations as others, shall be imported or brought into this Commonwealth of England, or into Ireland, or any other lands, islands, plantations, or territories to this Commonwealth belonging, or in their possession, in any other ship or ships, vessel or vessels whatsoever, but only in such as do truly and without fraud belong only to the people of this Commonwealth, or the plantations thereof, as the proprietors or right owners thereof; and whereof the master and mariners are also of the people of this Commonwealth, under the penalty of the forfeiture and loss of all the goods that shall be imported contrary to this act, , , ,\n[N]o goods or commodities of the growth, production, or manufacture of Europe, or of any part thereof, shall after the first day of December, one thousand six hundred fifty and one, be imported or brought into this Commonwealth of England, or any other lands or territories to this Commonwealth belonging, or in their possession, in any ship or ships, vessel or vessels whatsoever, but in such as do truly and without fraud belong only to the people of this Commonwealth, and in no other, except only such foreign ships and vessels as do truly and properly belong to the people of that country or place, of which the said goods are the growth, production or manufacture.\nWhich of the following best describes the outcome of the Navigation Acts of 1651?\n(A) They served as a catalyst for the growth of English shipping and overseas trade, but did little to limit the prospects of the Dutch in the seventeenth century. (B) They brought about almost immediate hardships for the Dutch economy as their dominance of overseas trade quickly ended. (C) They were rescinded during the restoration of the Stuarts as they sought normal diplomatic relations with the Dutch so not as to need Parliament's financial support for war. (D) They led to nearly a century of recurrent war between England and the Netherlands, which would not end until after American independence.\nA: Let's think step by step. We refer to Wikipedia articles on european history for help. The Navigation Acts of 1651 helped English shipping by restricting the ability of ships from other European countries, especially the Dutch, to transport goods from colonies in Asia and Africa into England. The answer is (A).\n\nQ: This question refers to the following information.\nIn Russia there was nothing going on well, and [Souvarine] was in despair over the news he had received. His old companions were all turning to the politicians; the famous Nihilists who made Europe tremble-sons of village priests, of the lower middle class, of tradesmen-could not rise above the idea of national liberation, and seemed to believe that the world would be delivered-when they had killed their despot&\u2026\n\"Foolery! They'll never get out of it with their foolery.\"\nThen, lowering his voice still more, in a few bitter words he described his old dream of fraternity. He had renounced his rank and his fortune; he had gone among workmen, only in the hope of seeing at last the foundation of a new society of labour in common. All the sous in his pockets had long gone to the urchins of the settlement; he had been as tender as a brother with the colliers, smiling at their suspicion, winning them over by his quiet workmanlike ways and his dislike of chattering. But decidedly the fusion had not taken place.\nHis voice changed, his eyes grew bright, he fixed them on \u00e9tienne, directly addressing him:\n\"Now, do you understand that? These hatworkers at Marseilles who have won the great lottery prize of a hundred thousand francs have gone off at once and invested it, declaring that they are going to live without doing anything! Yes, that is your idea, all of you French workmen; you want to unearth a treasure in order to devour it alone afterwards in some lazy, selfish corner. You may cry out as much as you like against the rich, you haven't got courage enough to give back to the poor the money that luck brings you. You will never be worthy of happiness as long as you own anything, and your hatred of the bourgeois proceeds solely from an angry desire to be bourgeois yourselves in their place.\"\n\u00e9mile Zola, French writer, Germinal, 1885\nThe passage displays the direct concern for the welfare of the working classes that was typically a part of which movement?\n(A) Capitalist (B) Scientific (C) Communist (D) Existentialist\nA: Let's think step by step. We refer to Wikipedia articles on european history for help. The modern Communist movement aims to establish a classless society based on communal ownership and distribution of property and means of production, thereby especially benefiting the working classes. The answer is (C).\n\nQ: This question refers to the following information.\nThe following excerpt is from a pamphlet.\nYou will do me the justice to remember, that I have always strenuously supported the Right of every man to his own opinion, however different that opinion might be to mine. He who denies to another this right, makes a slave of himself to his present opinion, because he precludes himself the right of changing it.\nThe most formidable weapon against errors of every kind is Reason. I have never used any other, and I trust I never shall.\nThe circumstance that has now taken place in France of the total abolition of the whole national order of priesthood, and of everything appertaining to compulsive systems of religion, and compulsive articles of faith, has not only precipitated my intention, but rendered a work of this kind exceedingly necessary, lest in the general wreck of superstition, of false systems of government, and false theology, we lose sight of morality, of humanity, and of the theology that is true.\nI believe in one God, and no more; and I hope for happiness beyond this life.\nI believe in the equality of man; and I believe that religious duties consist in doing justice, loving mercy, and endeavoring to make our fellow-creatures happy.\nI do not believe in the creed professed by the Jewish church, by the Roman church, by the Greek church, by the Turkish church, by the Protestant church, nor by any church that I know of. My own mind is my own church.\nAll national institutions of churches, whether Jewish, Christian or Turkish, appear to me no other than human inventions, set up to terrify and enslave mankind, and monopolize power and profit.\nI do not mean by this declaration to condemn those who believe otherwise; they have the same right to their belief as I have to mine.\n\u2014Thomas Paine, The Age of Reason, 1794\u20131795\nWhich of the following Enlightenment philosophes designed a system of checks and balances for government to avoid abuses of power?\n(A) Jean Jacques Rousseau (B) Baron Montesquieu (C) Mary Wollstonecraft (D) Adam Smith\nA: Let's think step by step. We refer to Wikipedia articles on european history for help. Baron Montesquieu was a 18th centrury French philsopher who wrote extensively against the monoplization of power and advocated for a system of checks and balances in government to prevent the rise of despotism. The answer is (B).\n\n", "high_school_geography": "The following are multiple choice questions (with answers) about high school geography.\n\nQ: Which one of the following items is an example of nonmaterial culture?\n(A) Dove soap (B) Dove candy bar (C) Dove symbol (D) A dove (bird).\nA: Let's think step by step. We refer to Wikipedia articles on geography for help. Nonmaterial culture consists of cultural ideas, beliefs or symbols that are not physical objects. The answer is (C).\n\nQ: During the third stage of the demographic transition model, which of the following is true?\n(A) Birth rates increase and population growth rate is less rapid. (B) Birth rates decline and population growth rate is less rapid. (C) Birth rates increase and population growth rate increases. (D) Birth rates decrease and population growth rate increases.\nA: Let's think step by step. We refer to Wikipedia articles on geography for help. The demographic transition model models the five different stages of population growth as a country goes through economic development, where the third stage refers to a period of declining birth rates and lower population growth. The answer is (B).\n\nQ: The practice of hiring a foreign third-party service provider to run an operation is called\n(A) outsourcing. (B) offshoring. (C) maquiladoras. (D) locational interdependence.\nA: Let's think step by step. We refer to Wikipedia articles on geography for help. \"Offshoring\" literally means to move or base some of the activities or processes of a company to a foreign country. The answer is (B).\n\nQ: Which of the following statements is NOT accurate regarding the services provided by local governments in the United States?\n(A) Duplication of efforts occurs often. (B) Social problems of the central city spill over into the surrounding residential suburbs. (C) Inefficiency in providing services occurs often. (D) One neighborhood's efforts to reduce pollution are always supported by neighboring communities.\nA: Let's think step by step. We refer to Wikipedia articles on geography for help. There may be economic, social or political reasons for two neighboring communities and their local governments not agreeing to pollution reduction efforts initiated by one of them. The answer is (D).\n\nQ: The rate of natural increase of a population is found by subtracting the\n(A) crude death rate from the crude birth date. (B) crude birth rate from the crude death rate. (C) doubling time from the crude birth rate. (D) fertility rate from the crude death rate.\nA: Let's think step by step. We refer to Wikipedia articles on geography for help. The difference between number of births and deaths gives the population increase at any given time. The answer is (A).\n\n", "high_school_government_and_politics": "The following are multiple choice questions (with answers) about high school government and politics.\n\nQ: Which of the following best states an argument made by James Madison in The Federalist number 10?\n(A) Honest politicians can prevent factions from developing. (B) Factions are more likely to occur in large republics than in small ones. (C) The negative effects of factionalism can be reduced by a republican government. (D) Free elections are the people's best defense against factionalism.\nA: Let's think step by step. We refer to Wikipedia articles on government and politics for help. In the Federalist number 10, James Madison advocated for a representative republican form of government to guard against factionalism. The answer is (C).\n\nQ: The term \"budget deficit\" refers to the\n(A) annual increase in federal spending on the military (B) amount of interest on the national debt (C) difference between the initial budget proposals made by the president and Congress (D) amount the government spends in excess of its revenues\nA: Let's think step by step. We refer to Wikipedia articles on government and politics for help. When the goverment spends more than it earns, their difference is the budget deficit. The answer is (D).\n\nQ: Which of the following statements about cabinet departments is FALSE?\n(A) They are established by the legislative branch. (B) Their members often don't have much influence over presidential decisions. (C) They cannot all be run by leaders who belong to the same political party the president does. (D) Not every federal agency is a cabinet department.\nA: Let's think step by step. We refer to Wikipedia articles on government and politics for help. There is no law stipulating that some cabinet department leaders have to belong to a political party different from that of the president. The answer is (C).\n\nQ: Which of the following cases established the precedent that a defendant must be informed of the right to remain silent, the right to a lawyer, and protection from self-incrimination?\n(A) Weeks v. United States (B) Betts v. Brady (C) Mapp v. Ohio (D) Miranda v. Arizona\nA: Let's think step by step. We refer to Wikipedia articles on government and politics for help. In the landmark Miranda v. Arizona in 1966, the US Supreme Court, based on the Fifth and Sixth Amendment of the US Constitution, guaranteed a defendant's right to an attorney and protection from self-incrimination. The answer is (D).\n\nQ: Uncertainty over the limits to presidential power is caused primarily by the fact that\n(A) the constitutional definition of those powers is broad and unspecific (B) most people agree that the Constitution places too many limits on presidential power (C) the Supreme Court consistently refuses to rule on cases concerning presidential powers (D) constitutional amendments have greatly increased presidential powers\nA: Let's think step by step. We refer to Wikipedia articles on government and politics for help. The US Constitution is not very specific about the powers of the president, leading to uncertainty over its limits. The answer is (A).\n\n", "high_school_macroeconomics": "The following are multiple choice questions (with answers) about high school macroeconomics.\n\nQ: Which of the following policies best describes supply-side fiscal policy?\n(A) An increase in the money supply (B) Increased government spending (C) Lower taxes on research and development of new technology (D) Higher taxes on household income\nA: Let's think step by step. We refer to Wikipedia articles on macroeconomics for help. Supply-side fiscal policy stimulates the economy by encouraging more production of goods and services through reduction in taxes and deregulation. The answer is (C).\n\nQ: The short-run Phillips curve indicates a\n(A) direct relation between unemployment and inflation (B) direct relation between price and quantity demanded (C) inverse relation between price and quantity demanded (D) inverse relation between unemployment and inflation\nA: Let's think step by step. We refer to Wikipedia articles on macroeconomics for help. The short-run Phillips curve shows that whenever unemployment decreases below a natural level, the inflation starts increasing, and vice-versa. The answer is (D).\n\nQ: Holding all else equal which of the following monetary policies would be used to boost U.S. exports?\n(A) Increasing the discount rate (B) Increasing the reserve ratio (C) Buying government securities (D) Lowering tariffs\nA: Let's think step by step. We refer to Wikipedia articles on macroeconomics for help. Buying government securities leads to reduction in demand for US dollars from foreign buyers, thereby making it cheaper and hence making US exports more attractive. The answer is (C).\n\nQ: A federal deficit occurs when\n(A) exports exceed imports. (B) imports exceed exports. (C) federal tax collections exceed spending. (D) federal spending exceeds federal tax revenues.\nA: Let's think step by step. We refer to Wikipedia articles on macroeconomics for help. A federal deficit occurs when federal spending exceeds federal income which is primarily from tax revenues. The answer is (D).\n\nQ: Which of the following is not included in the U.S. GDP?\n(A) The U.S. military opens a new base in a foreign country with 1000 U.S. personnel. (B) Japanese consumers buy thousands of CDs produced in the United States. (C) An American pop singer performs a sold-out concert in Paris. (D) A French theatrical production tours dozens of American cities.\nA: Let's think step by step. We refer to Wikipedia articles on macroeconomics for help. The economic transactions related to the performance of the American pop-singer in Paris happens entirely outside the U.S. and hence is not included in the GDP numbers. The answer is (C).\n\n", "high_school_mathematics": "The following are multiple choice questions (with answers) about high school mathematics.\n\nQ: Simplify and write the result with a rational denominator: $$\\sqrt{\\sqrt[3]{\\sqrt{\\frac{1}{729}}}}$$\n(A) \\frac{3\\sqrt{3}}{3} (B) \\frac{1}{3} (C) \\sqrt{3} (D) \\frac{\\sqrt{3}}{3}\nA: Let's think step by step. Factoring $729=3^6$ and combining the roots $\\frac{1}{2}\\frac{1}{3}\\frac{1}{2}=\\frac{1}{12}$, we get that $\\sqrt{\\sqrt[3]{\\sqrt{\\frac{1}{729}}}}=\\left(\\frac{1}{3^6}\\right)^{\\frac{1}{12}}=\\frac{1}{3^{\\frac{1}{2}}}=\\frac{3}{\\sqrt{3}}$ The answer is (D).\n\nQ: Five thousand dollars compounded annually at an $x\\%$ interest rate takes six years to double. At the same interest rate, how many years will it take $\\$300$ to grow to $\\$9600$?\n(A) 12 (B) 1 (C) 30 (D) 5\nA: Let's think step by step. To go from $\\$300$ to $\\$9600$, the value must go up by a factor of $9600/300=32=2^5$. Since at this interest rate it takes six years for it to double, it will take $5*6=30$ years to grow to $\\$9600$. The answer is (C).\n\nQ: Ten students take a biology test and receive the following scores: 45, 55, 50, 70, 65, 80, 40, 90, 70, 85. What is the mean of the students\u2019 test scores?\n(A) 55 (B) 60 (C) 62 (D) 65\nA: Let's think step by step. There are 10 students and the sum of their scores is $45 + 55 + 50 + 70 + 65 + 80 + 40 + 90 + 70 + 85 = 650$, the mean is $650/10=65$. The answer is (D).\n\nQ: The variable $x$ varies directly as the square of $y$, and $y$ varies directly as the cube of $z$. If $x$ equals $-16$ when $z$ equals 2, what is the value of $x$ when $z$ equals $\\frac{1}{2}$?\n(A) -1 (B) 16 (C) -\\frac{1}{256} (D) \\frac{1}{16}\nA: Let's think step by step. We know that $x \\propto y^2$ and $y \\propto z^3$, so $x = k z^6$ for some constant $k$. Plugging in for $x=-16$ and $z=2$, the constant value is $k=\\frac{x}{z^6}=\\frac{-16}{64}=-\\frac{1}{4}$. So, when $z=\\frac{1}{2}$, the value of $x$ is $x=kz^6=-\\frac{1}{4}\\frac{1}{2^6}=-\\frac{1}{256}$. The answer is (C).\n\nQ: Joe was in charge of lights for a dance. The red light blinks every two seconds, the yellow light every three seconds, and the blue light every five seconds. If we include the very beginning and very end of the dance, how many times during a seven minute dance will all the lights come on at the same time? (Assume that all three lights blink simultaneously at the very beginning of the dance.)\n(A) 3 (B) 15 (C) 6 (D) 5\nA: Let's think step by step. The least common multiple of 2, 3 and 5 is 30, so during a 7 minute dance, all the three lights will come on at the same time $2*7+1=15$ times. The answer is (B).\n\n", "high_school_microeconomics": "The following are multiple choice questions (with answers) about high school microeconomics.\n\nQ: Which of the following is necessarily a characteristic of oligopoly?\n(A) Free entry into and exit from the market (B) A few large producers (C) One producer of a good with no close substitutes (D) A homogenous product\nA: Let's think step by step. We refer to Wikipedia articles on microeconomics for help. An oligopoly is when a market is dominated by just one or a few number of sellers or producers. To get oligopoly, the market should have high barriers to new entry, and the product has differentiation. The answer is (B).\n\nQ: If the government subsidizes producers in a perfectly competitive market, then\n(A) the demand for the product will increase (B) the demand for the product will decrease (C) the consumer surplus will increase (D) the consumer surplus will decrease\nA: Let's think step by step. We refer to Wikipedia articles on microeconomics for help. (A) and (B) are wrong because the demand curve does not change at all. If the government subsidizes producers, the supply will increase, and thus the consumer surplus also increases. The answer is (C).\n\nQ: Which of the following is true of a price floor?\n(A) The price floor shifts the demand curve to the left. (B) An effective floor creates a shortage of the good. (C) The price floor shifts the supply curve of the good to the right. (D) To be an effective floor, it must be set above the equilibrium price.\nA: Let's think step by step. We refer to Wikipedia articles on microeconomics for help. Price floor does not shift the demand or shift curve. An effective price floor should be set above the equilibrium price, otherwise the market bears and the floor does not have effective effect. The answer is (D).\n\nQ: The concentration ratio for a monopoly is\n(A) 0 (B) 5 (C) 10 (D) 100\nA: Let's think step by step. We refer to Wikipedia articles on microeconomics for help. The concentration ratio is calculated as the sum of market share of a specific number of largest companies. Monopoly means one company or entity controls the entire market, therefore, the concentration ratio is 100 percent. The answer is (D).\n\nQ: In a competitive labor market for housepainters, which of the following would increase the demand for housepainters?\n(A) An effective minimum wage imposed on this labor market. (B) An increase in the price of gallons of paint. (C) An increase in the construction of new houses. (D) An increase in the price of mechanical painters so long as the output effect exceeds the substitution effect.\nA: Let's think step by step. We refer to Wikipedia articles on microeconomics for help. An increase in the construction of new houses means an increase demand of in-house painting, thus increases the demand for housepainters. The answer is (C).\n\n", "high_school_physics": "The following are multiple choice questions (with answers) about high school physics.\n\nQ: A microwave oven is connected to an outlet, 120 V, and draws a current of 2 amps. At what rate is energy being used by the microwave oven?\n(A) 10 W (B) 30 W (C) 60 W (D) 240 W\nA: Let's think step by step. Rate of energy usage is known as power; in an dissipative electrical circuit, power is given by voltage times current. So in our case, the power is 120 V times 2 amps, or 240 W. The answer is (D).\n\nQ: A point charge, Q = +1 mC, is fixed at the origin. How much work is required to move a charge, Q = +8 \u00b5C, from the point (0, 4 meters) to the point (3 meters, 0)?\n(A) 3.5 J (B) 6.0 J (C) 22.5 J (D) 40 J\nA: Let's think step by step. To calculate the work required to move a charge from one location to another in a fixed electric field, it is enough to calculate the potential difference between the two locations. Here, the potential only depends on the distance between the charges; it\u2019s $k q_1 q_2 / r$, where $k$ is Coulomb\u2019s constant. Plugging in values $q_1 = $ 1 mC, $q_2 = 8 \\mu$ C, gives the answer as 5.992 J, which rounds to 6 J. The answer is (B).\n\nQ: Which of the following conditions will ensure that angular momentum is conserved? I. Conservation of linear momentum II. Zero net external force III. Zero net external torque\n(A) I and II only (B) I and III only (C) II and III only (D) III only\nA: Let's think step by step. Torque is defined as the change in angular momentum; if there is zero external torque, angular momentum is conserved. The answer is (D).\n\nQ: A photocell of work function \u03d5 = 2eV is connected to a resistor in series. Light of frequency f = 1 \u00d7 10^15 Hz hits a metal plate of the photocell. If the power of the light is P = 100 W, what is the current through the resistor?\n(A) 2:00 AM (B) 6:00 AM (C) 12:00 AM (D) 24 A\nA: Let's think step by step. The only answer above which has units of current is D, 24 A. The answer is (D).\n\nQ: A pipe full of air is closed at one end. A standing wave is produced in the pipe, causing the pipe to sound a note. Which of the following is a correct statement about the wave\u2019s properties at the closed end of the pipe?\n(A) The pressure is at a node, but the particle displacement is at an antinode. (B) The pressure is at an antinode, but the particle displacement is at a node. (C) The pressure and the particle displacement are both at nodes. (D) The pressure and the particle displacement are both at antinodes.\nA: Let's think step by step. At the closed end of the pipe, the particles cannot have any net displacement because the pipe closure stops them. So the particle displacement is at a node. This closure also causes the pressure to be maximal, i.e. an antinode. The answer is (B).\n\n", "high_school_psychology": "The following are multiple choice questions (with answers) about high school psychology.\n\nQ: Pascale is interested in the processing strategies children use to learn new information. Pascale would best be classified as what type of psychologist?\n(A) sociocultural (B) clinical (C) cognitive (D) behaviorist\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. Sociocultural psychologist focuses on the effect of societal factors on people. Clinical psychologist focuses on people with mental issues. Cognitive psychologist focuses on how people think and learn, including the processing strategies. Behaviorist focuses more on the environment and experience effect on people. The answer is (C).\n\nQ: According to Caplan's model of consultee-centered case consultation, the consultant is primarily interested in\n(A) identifying the causes and solutions of the client's presenting problems (B) identifying and eliminating the causes of the consultee's difficulties in handling a problem (C) establishing a hierarchy of authority to enable effective decision making (D) presenting a single, well-defined and unambiguous course of action for the consultant to overcome skills deficits\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. Caplan defines two type of consultation. Client-centered case consultation aims to handle client's problems, while consultee-centered case consultation aims to identify the reason of client's difficulty to solve problems. The answer is (B).\n\nQ: According to the Individuals with Disabilities Education Improvement Act, which of the following must an educational agency do before it changes the educational placement of a student with a disability?\n(A) Give the child a trial period in the new environment (B) Notify the parents in writing (C) Obtain school board approval (D) Obtain parental consent\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. When the decision to change the educational placement of a student with a disability is made, the educational agency must notify the parents in writing on that date. The answer is (B).\n\nQ: While swimming in the ocean, Ivan is frightened by a dark shadow in the water even before he has the chance to identify what the shadow is. The synaptic connections taking place during this incident of fright are best described by which of the following?\n(A) Messages are sent from the thalamus directly to the amygdala. (B) Messages are sent from the thalamus to the \"what\" and \"where\" pathways. (C) Messages are sent from the parasympathetic nervous system to the cerebral cortex. (D) Messages are sent from the frontal lobes to the pituitary gland.\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. Our neural system has a mechanism that can respond immediate emotional signal before going to the thought center. In the Ivan's case, messages travel directly from thalamus to amygdala. The answer is (A).\n\nQ: Ani believes that her attitudes and behavior play a central role in what happens to her. Such a belief is likely to be associated with\n(A) a strong superego. (B) low self-esteem. (C) low self-efficacy. (D) an internal locus of control.\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. People with an external locus of control believes fate and luck play an important role in their lives, while people with an internal locus of control believes they control their lives. The answer is (D).\n\n", "high_school_statistics": "The following are multiple choice questions (with answers) about high school statistics.\n\nQ: A new smartwatch is manufactured in one part of a factory, then secured for shipping in another, independent part of the factory. The weight of the smartwatch has a mean of 62 grams and a standard deviation of 1.0 grams. The weight of the packaging (box, user's guide, bubble wrap, etc.) has a mean of 456 grams and a standard deviation of 6 grams. Together, the distribution of the weight of the smartwatch and its packaging would have the following mean and standard deviation:\n(A) Mean 518 grams; standard deviation 7.0 grams (B) Mean 518 grams; standard deviation 3.5 grams (C) Mean 518 grams; standard deviation 6.1 grams (D) Mean 394 grams; standard deviation 6.1 grams\nA: Let's think step by step. Since the weight of the watch and the weight of the packaging are independent random variables, the mean and variance of their sum is equal to the sum of their individual means and variances. So the mean is 62 + 456 = 518 grams, and the variances is 1.0^2 + 6.0^2 = 37, leading to a standard deviation of 6.1 grams. The answer is (C).\n\nQ: After a frost warning was issued, the owner of a large orange grove asked his workers to spray all his trees with water. The water was supposed to freeze and form a protective covering of ice around the orange blossom. Nevertheless, the owner suspected that some trees suffered considerable damage due to the frost. To estimate the proportion of trees that suffered more than 50 percent damage due to the frost, he took a random sample of 100 trees from his grove. What is the response variable in this experiment?\n(A) The proportion of trees that suffered more than 50 percent damage due to frost. (B) The number of trees affected by the frost. (C) The number of trees sampled from the grove. (D) For each sampled tree, whether it suffered more than 50 percent damage or at most 50 percent damage.\nA: Let's think step by step. In this experiment, the response variable is what is measured. For each tree, what is measured is whether or not it suffered more than 50 percent damage due to the frost. The answer is (D).\n\nQ: Suppose X and Y are random variables with E(X) = 37, var(X) = 5, E(Y) = 62, and var(Y) = 12. What are the expected value and variance of the random variable X + Y?\n(A) E(X + Y) = 99, var(X + Y) = 8.5 (B) E(X + Y) = 99, var(X + Y) = 13 (C) E(X + Y) = 99, var(X + Y) = 17 (D) There is insufficient information to answer this question.\nA: Let's think step by step. While means of sums of random variables add (regardless of whether the variables are independent) in order to determine the variance of a sum of random variables, we need to know not just their individual variances but the covariance of the two variables, which is not given in this problem. The answer is (D).\n\nQ: Which of the following sets has the smallest standard deviation? Which has the largest?\nI: {1,2,3}\nII: {-10,10}\nIII: {100}\n(A) I, II (B) II, III (C) III, I (D) III, II\nA: Let's think step by step. The variance of distribution I is the expected squared deviation from its mean (which is 2), so the variance is 2/3 . The variance of distribution II is 10^2 (because both elements are 10 away from the mean of zero). The variance of distribution III is 0, since it has a single entry. So distribution III has the smallest standard deviation and distribution II has the largest. The answer is (D).\n\nQ: Which of the following is a correct statement about correlation?\n(A) If the slope of the regression line is exactly 1, then the correlation is exactly 1. (B) If the correlation is 0, then the slope of the regression line is undefined. (C) Switching which variable is called x and which is called y changes the sign of the correlation. (D) The correlation r is equal to the slope of the regression line when z-scores for the y-variable are plotted against z-scores for the x-variable.\nA: Let's think step by step. Statement A is false because the slope of the regression line being exactly 1 can occur even when the two variables are not perfectly correlated. Statement B is false because uncorrelated variables regression lines can have slope zero. Statement C is false because correlation is symmetric in the two random variables. The answer is (D).\n\n", "high_school_us_history": "The following are multiple choice questions (with answers) about high school us history.\n\nQ: This question refers to the following information.\nI come not to urge personal claims, nor to seek individual benefits; I appear as the advocate of those who cannot plead their own cause; I come as the friend of those who are deserted, oppressed, and desolate. In the Providence of God, I am the voice of the maniac whose piercing cries from the dreary dungeons of your jails penetrate not your Halls of Legislation. I am the Hope of the poor crazed beings who pine in the cells, and stalls, and cages, and waste rooms of your poor-houses. I am the Revelation of hundreds of wailing, suffering creatures, hidden in your private dwellings, and in pens and cabins\u2014shut out, cut off from all healing influences, from all mind-restoring cares.\u2026 Could their melancholy histories be spread before you as revealed to my grieved spirit during the last three months, how promptly, how earnestly would you search out the most approved means of relief; how trifling, how insignificant, by comparison, would appear the sacrifices you are asked to make; how would a few dimes and dollars, gathered from each citizen, diminish in value as a possession, compared with the certain benefits and vast good to be secured for the suffering insane...by the consecration and application of a sufficient fund to the construction of a suitable hospital.\u2026\n\u2014Dorothea Dix, Memorial Soliciting a State Hospital for the Protection and Cure of the Insane,\nSubmitted to the General Assembly of North Carolina, November 1848\nDorothea Dix can best be compared to whom?\n(A) Abigail Adams (B) Clara Barton (C) Shirley Temple (D) Hillary Clinton\nA: Let's think step by step. We refer to Wikipedia articles on us history for help. Both Dorothea Dix and Clara barton are American nurses. The answer is (B).\n\nQ: This question refers to the following information.\n\"As our late Conduct at the Conestoga Manor and Lancaster have occasioned much Speculation & a great diversity of Sentiments in this and neighboring Governments; some vindicating & others condemning it; some charitably alleviating the Crime, & others maliciously painting it in the most odious & detestable Colours, we think it our duty to lay before the Publick, the whole Matter as it appeared, & still appears, to us. . . .\n\"If these things are not sufficient to prove an unjustifiable Attachment in the Quakers to the Indians Savages, a fixed Resolution to befriend them & an utter insensibility to human Distresses, let us consider a few more recent Facts. When we found the last Summer that we were likely to get no Assistance from the Government, some Volunteers went out at our own Expense, determined to drive our Enemies from our Borders; & when we came near to the great Island, we understood that a Number of their Warriors had gone out against our Frontiers. Upon this we returned and came up with them and fought with them at the Munfey Hill where we lost some of our Men & killed some of their Warriors & thereby saved our Frontiers from this Story in another Expedition. But no sooner had we destroyed their Provisions on the great Island, & ruined their trade with the good People at Bethlehem, but these very Indians, who were justly suspected of having murdered our Friends in Northampton County, were by the Influence of some Quakers taken under the Protection of the Government to screen them from the Resentments of the Friends and Relations of the Murdered, & to support them thro the Winter.\"\n\u2014\"Apology of the Paxton Boys\" (pamphlet), 1764 (Note: \"apology\" in this context should be read as an explanation, not an admission of guilt or regret.\nThe sentiments expressed in the explanation above reflect which of the ongoing tensions during the colonial period of American history?\n(A) Tensions between British policies and the aspirations of North American colonists. (B) Tensions between American Indians allied with the French and those allied with the British. (C) Tensions between freed African Americans and white planters. (D) Tensions between backcountry settlers and elites within colonial America.\nA: Let's think step by step. We refer to Wikipedia articles on us history for help. After the French and Indian War, the Scotch-Irish settlers attacked American Indians. After the attacks on the Conestoga, about 250 Paxton Boys present their grievances to the Pennsylvania legislature. As mentioned in the information, the Paxton Boys cited resentiment at local elites. The answer is (D).\n\nQ: This question refers to the following information.\nOur leaders talk about stopping aggression from the north, but this was a struggle among groups of Vietnamese until we intervened. We seem bent upon saving the Vietnamese from Ho Chi Minh even if we have to kill them and demolish their country to do it. As the native people survey bombed-out villages, women and children burned by napalm, rice crops destroyed and cities overrun with our military personnel, they are doubtless saying secretly of the Vietcong guerillas and of the American forces, \"A plague on both your houses.\" \u2026 Stop the bombing, north and south, end search and destroy offensive sweeps, and confine our military action to holding operations on the ground. Bombing the north has failed to halt or seriously check the flow of troops to the south and may, in fact, have prompted a much greater war effort by Hanoi.\n\u2014Senator George McGovern, \"The Lessons of Vietnam,\" April 25, 1967\nWhich of the following opinions from the 1960s most directly reflects the perspective of George McGovern's speech?\n(A) Americans must maximize their technological edge in Vietnam. (B) American bombing in Vietnam is step by step leading to progress in the war. (C) American bombing in Vietnam is a failure. (D) America must not give in to defeatism about the war in Vietnam.\nA: Let's think step by step. We refer to Wikipedia articles on us history for help. \"Stop the bombing\" and \"Bombing the north has failed to halt or seriously check the flow of troops to the south\" indicate that the perspective of George McGovern's speech is that Amerian bombing in Vietnam is a failure. The answer is (C).\n\nQ: This question refers to the following information.\n\"In the new Code of Laws which I suppose it will be necessary for you to make I desire you would Remember the Ladies, and be more generous and favorable to them than your ancestors. Do not put such unlimited power into the hands of the Husbands. Remember all Men would be tyrants if they could. If particular care and attention is not paid to the Ladies we are determined to foment a Rebellion, and will not hold ourselves bound by any Laws in which we have no voice, or Representation.\"\nAbigail Adams, in a letter to John Adams, 1776\n\"Special legislation for woman has placed us in a most anomalous position. Women invested with the rights of citizens in one section\u2014voters, jurors, office-holders\u2014crossing an imaginary line, are subjects in the next. In some States, a married woman may hold property and transact business in her own name; in others, her earnings belong to her husband. In some States, a woman may testify against her husband, sue and be sued in the courts; in others, she has no redress in case of damage to person, property, or character. In case of divorce on account of adultery in the husband, the innocent wife is held to possess no right to children or property, unless by special decree of the court. But in no State of the Union has the wife the right to her own person, or to any part of the joint earnings of the co-partnership during the life of her husband. In some States women may enter the law schools and practice in the courts; in others they are forbidden. In some universities girls enjoy equal educational advantages with boys, while many of the proudest institutions in the land deny them admittance, though the sons of China, Japan and Africa are welcomed there. But the privileges already granted in the several States are by no means secure.\"\nSusan B. Anthony, \"Declaration of Rights for Women,\" July 4, 1876\nThe sentiments expressed in the second excerpt by Susan B. Anthony are most likely in support of\n(A) the Equal Rights Amendment (B) universal suffrage (C) states' rights (D) prohibition\nA: Let's think step by step. We refer to Wikipedia articles on us history for help. The above information mentioned that women are in an anomalous position in terms of legislation. Women's earnings do not belong to themselves, or they cannot testify against her husbands. Susan believes women should have equal legal rights as men. The answer is (B).\n\nQ: This question refers to the following information.\n\"Society in every state is a blessing, but government even in its best state is but a necessary evil; in its worst state an intolerable one; for when we suffer, or are exposed to the same miseries by a government, which we might expect in a country without government, our calamity is heightened by reflecting that we furnish the means by which we suffer. Government, like dress, is the badge of lost innocence; the palaces of kings are built on the ruins of the bowers of paradise. For were the impulses of conscience clear, uniform, and irresistibly obeyed, man would need no other lawgiver; but that not being the case, he finds it necessary to surrender up a part of his property to furnish means for the protection of the rest; and this he is induced to do by the same prudence which in every other case advises him out of two evils to choose the least. Wherefore, security being the true design and end of government, it unanswerably follows that whatever form thereof appears most likely to ensure it to us, with the least expense and greatest benefit, is preferable to all others.\"\nThomas Paine, Common Sense, 1776\nWhich of the following \"miseries\" alluded to above were most condemned by Anti-Federalists of the post-Revolutionary era?\n(A) Organized response to Bacon's Rebellion (B) Federal response to Shays's Rebellion (C) Federal response to the Whiskey Rebellion (D) Federal response to Pontiac's Rebellion\nA: Let's think step by step. We refer to Wikipedia articles on us history for help. Anti-Federalists do not believe centralized government power, and suspect Washington's military response to Whiskey Rebellion. Bacon's Rebellion and Pontiac's Rebellion happen before the Revolution and they can be ruled out. The answer is (C).\n\n", "high_school_world_history": "The following are multiple choice questions (with answers) about high school world history.\n\nQ: This question refers to the following information.\n\"At least one of the [world's] societies would have to somehow enormously increase its productivity [in order to achieve global hegemony]. That quantum jump would have to be made before the various scientific, technological, agricultural, and industrial revolutions on which our post-quantum-leap world rests. It could only be accomplished by exploiting the ecosystems, mineral resources, and human assets of whole continents outside the lands of the society making the jump. Western Europe did just that by means of its brutality and guns and, more important, by geographical and ecological luck.\"\nCopyright \u00a9 2015 Cambridge University Press.\nAlfred Crosby, historian, Ecological Imperialism, 2004\nThe \"quantum jump\" mentioned in the passage most directly contributed to which of the following developments in the period 1450\u20131750 C.E.?\n(A) A breakdown in trade routes through the collapse of the established state structure (B) An increase in the population of the world through more plentiful supplies of food (C) The spread of Chinese and Indian belief systems across the world (D) An increase in social unrest\nA: Let's think step by step. We refer to Wikipedia articles on world history for help. The \"quantum jump\" mentioned in the passage refers to the conquest of the New World and the Columbian Exchange. Choice (A) and (C) did not happen in history. Choice (C) refers to the human assets. The answer is (B).\n\nQ: This question refers to the following information.\n\"The struggle against neo-colonialism is not aimed at excluding the capital of the developed world from operating in less developed countries. It is aimed at preventing the financial power of the developed countries being used in such a way as to impoverish the less developed.\nNon-alignment, as practiced by Ghana and many other countries, is based on co-operation with all States whether they be capitalist, socialist or have a mixed economy. Such a policy, therefore, involves foreign investment from capitalist countries, but it must be invested in accordance with a national plan drawn up by the government of the non-aligned State with its own interests in mind. The issue is not what return the foreign investor receives on his investments\u2026The question is one of power. A State in the grip of neo-colonialism is not master of its own destiny.\"\nKwame Nkrumah, Neo-Colonialism, 1965\nWhich of the following provides the best context for Nkrumah's writings?\n(A) The Industrial Revolution (B) Decolonization (C) Regional Free Trade Associations (D) Autarky\nA: Let's think step by step. We refer to Wikipedia articles on world history for help. The passage expresses a point that the successful fight against neo-colonialism were in danger and the newly independent nations like Ghana may be re-colonized via financial power of the developed countries. The answer is (B).\n\nQ: This question refers to the following information.\n\"Indeed, as both the fatwas of distinguished [scholars] who base their opinion on reason and tradition alike and the consensus of the Sunni community agree that the ancient obligation of extirpation, extermination, and expulsion of evil innovation must be the aim of our exalted aspiration, for \"Religious zeal is a victory for the Faith of God the Beneficent\"; then, in accordance with the words of the Prophet (Peace upon him!) \"Whosoever introduces evil innovation into our order must be expelled\" and \"Whosoever does aught against our order must be expelled,\" action has become necessary and exigent\u2026\"\nLetter from Ottoman Sultan Selim I to Safavid Shah Ismail I, 1514\nThe letter from Selim I is most clearly an example of which of the following?\n(A) The maintenance of military supremacy at all costs (B) Expanding tensions between religious sects (C) Factors that brought about the collapse of the Ottoman Empire (D) Peacemaking efforts among the Islamic empires\nA: Let's think step by step. We refer to Wikipedia articles on world history for help. The passage is an example of expanding tensions between Selim and Ismail. In the passage the Selim references the fatwa and the consensus of the Sunni community to against whosoever introduces evil. The answer is (B).\n\nQ: This question refers to the following information.\n\"The real grievance of the worker is the insecurity of his existence; he is not sure that he will always have work, he is not sure that he will always be healthy, and he foresees that he will one day be old and unfit to work. If he falls into poverty, even if only through a prolonged illness, he is then completely helpless, exam_ins to his own devices, and society does not currently recognize any real obligation towards him beyond the usual help for the poor, even if he has been working all the time ever so faithfully and diligently. The usual help for the poor, however, leaves a lot to be desired, especially in large cities, where it is very much worse than in the country.\"\nOtto von Bismarck, 1884\nOtto von Bismarck likely made this speech in reaction to which of the following issues?\n(A) Social acceptance of child labor (B) Declining life expectancy in Germany (C) Criticisms of German trade tariffs (D) Negative effects attributed to industrial capitalism\nA: Let's think step by step. We refer to Wikipedia articles on world history for help. The passage talks about the grievance of the work under the industrial capitalism. The answer is (D).\n\nQ: This question refers to the following information.\nHe contains all works and desires and all perfumes and all tastes. He enfolds the whole universe and in silence is loving to all. This is the Spirit that is in my heart, this is Brahman. To him I shall come when I go beyond this life, and to him will come he who has faith and doubts not.\n\u2014The Upanishads, India, c. 1000 BCE\nTo which religion does the speaker most likely belong?\n(A) Hinduism (B) Buddhism (C) Shintoism (D) Zoroastrianism\nA: Let's think step by step. We refer to Wikipedia articles on world history for help. Brahman refers to the ultimate reality of all things in the Hindu religion. In contrast, Buddhism does not have a concept of supreme God. The answer is (A).\n\n", "human_aging": "The following are multiple choice questions (with answers) about human aging.\n\nQ: All other things being equal, which of the following persons is more likely to show osteoporosis?\n(A) An older Hispanic American woman (B) An older African American woman (C) An older Asian American woman (D) An older Native American woman\nA: Let's think step by step. We refer to Wikipedia articles on human aging for help. Although osteoporosis can occur at any age, the risk is higher for older people. It is most common in Asian and non-Hispanic white women. The answer is (C).\n\nQ: The finding that adults tend to remember events from their adolescence better than from other periods in their lives is referred to as the\n(A) Adolescence advantage (B) Reminiscence bump (C) Memorial memorial (D) Quadratic retrieval spike\nA: Let's think step by step. We refer to Wikipedia articles on human aging for help. Reminiscence bump is a phenomenon that older adults tend to recollect events during their young ages. People usually have a period of childhood amnesia from birth to around age 5, and a reminiscence bump between 10 and 30. The answer is (B).\n\nQ: Which element in tobacco smoke is responsible for cancers?\n(A) Nicotine (B) Tar (C) Carbon monoxide (D) Smoke particles\nA: Let's think step by step. We refer to Wikipedia articles on human aging for help. The benzene, acrylamide and acrylonitrile in tar interact with the lungs and cause DNA mutations in cells of the lungs, and lead to cancer. The answer is (B).\n\nQ: When older adults move to a new state after retirement, which of the following is the more likely destination?\n(A) Texas (B) California (C) Hawaii (D) Vermont\nA: Let's think step by step. We refer to Wikipedia articles on human aging for help. Texas does not have state tax, and has low cost of living compared with the other three options. The answer is (A).\n\n", "human_sexuality": "The following are multiple choice questions (with answers) about human sexuality.\n\nQ: The most common disorder among men who seek sexual therapy is:\n(A) premature ejaculation (B) inhibited ejaculation (C) erectile disorder (D) ejaculatory disorder\nA: Let's think step by step. We refer to Wikipedia articles on human sexuality for help. The most common type of arousal disorder for men is erectile dysfunction, meaning that a person is not able to get any physical satisfaction from sexual activity although he may be interested in it. The answer is (C).\n\nQ: A woman who knows she has active herpes and untreated syphilis but continues to have sex without informing her partners of her condition has, in psychoanalytic terms:\n(A) a strong ego (B) a weak superego (C) a weak id (D) a strong superego\nA: Let's think step by step. We refer to Wikipedia articles on human sexuality for help. A person with weak superego tends to be delinquent, criminal or have antisocial personality. The action of the woman who knows she has active venereal disease but still have sex with her partners indicate she may has antisocial personality. The answer is (B).\n\nQ: The nature of homosexual activities that occur during preadolescence include all but which of the following?\n(A) sexual intercourse (B) circle jerks (C) exhibitionism (D) touching each other's genitals\nA: Let's think step by step. We refer to Wikipedia articles on human sexuality for help. Sexual intercourse is prohibited during preadolescence. The answer is (A).\n\nQ: Women's ability to have multiple orgasms is primarily due to:\n(A) the fact that they do not have a refractory period. (B) the response of the inner layers of the vagina. (C) having alternating orgasms in different locations. (D) the G-Spot.\nA: Let's think step by step. We refer to Wikipedia articles on human sexuality for help. The refractory period is the time when a person is not able to be erect or is not interested in sex. The answer is (A).\n\nQ: Morning sickness is typically a problem:\n(A) during the first trimester (B) during the second trimester (C) during the third trimester (D) all through the pregnancy\nA: Let's think step by step. We refer to Wikipedia articles on human sexuality for help. Morning sickness usually begins by nine weeks after conception, corresponding to the first trimester. The answer is (A).\n\n", "international_law": "The following are multiple choice questions (with answers) about international law.\n\nQ: How the consent to be bound of a State may be expressed?\n(A) The consent of a State to be bound is expressed only by ratification (B) The consent of a state to be bound by a treaty may be expressed by signature, ratification, acceptance, approval or accession (C) The consent of a State to be bound is expressed by signature (D) The consent of a State to be bound is expressed by whatever means they choose\nA: Let's think step by step. We refer to Wikipedia articles on international law for help. Article 11 of Vienna Convention on the Law of Treaties signed in 1969 states that \"the consent of a State to be bound by a treaty may be expressed by signature, exchange of instruments constituting a treaty, ratification, acceptance, approval or accession, or by any other means if so agreed.\" (B) is the most precise and accurate answer. The answer is (B).\n\nQ: What is the judge ad hoc?\n(A) If a party to a contentious case before the ICJ does not have a national sitting as judge, it is entitled to nominate someone as a judge solely for that case, with the title of judge ad hoc (B) Judge ad hoc is the member of the bench of the ICJ with a casting vote (C) Judge ad hoc is a surrogate judge, in case a judge is disqualified or passes away (D) Judge ad hoc is the judge that each party will always nominate in every contentious case\nA: Let's think step by step. We refer to Wikipedia articles on international law for help. As \"ad hoc\" implies, a judge ad hoc is appointed only for a specific case or period, when a party to a contentious case before the International Court of Justice does not have a regular national sitting as judge. The answer is (A).\n\nQ: When 'consent' can serve as a circumstance precluding the wrongfulness of a State conduct?\n(A) Consent can serve as a circumstance precluding the wrongfulness whenever it is given (B) Consent can never serve as a circumstance precluding wrongfulness (C) Consent can serve as a circumstance precluding wrongfulness, provided the consent is valid and to the extent that the conduct remains within the limits of the consent given (D) Consent can always serve as a circumstance precluding wrongfulness, no matter which organ of the State gives it\nA: Let's think step by step. We refer to Wikipedia articles on international law for help. Valid consent can serve as a circumstance precluding the wrongfulness of a State conduct if the conduct remains within the limits of that consent, according to Chapter V of the Responsibility of States for Internationally Wrongful Acts, 2001, United Nations. The answer is (C).\n\nQ: Would a reservation to the definition of torture in the ICCPR be acceptable in contemporary practice?\n(A) This is an acceptable reservation if the reserving country's legislation employs a different definition (B) This is an unacceptable reservation because it contravenes the object and purpose of the ICCPR (C) This is an unacceptable reservation because the definition of torture in the ICCPR is consistent with customary international law (D) This is an acceptable reservation because under general international law States have the right to enter reservations to treaties\nA: Let's think step by step. We refer to Wikipedia articles on international law for help. For it contravenes the object and purpose of the ICCPR, this is an unacceptable reservation in contemporary practice. The answer is (B).\n\nQ: What types of force does Article 2(4) of the UN Charter prohibit?\n(A) Article 2(4) encompasses only armed force (B) Article 2(4) encompasses all types of force, including sanctions (C) Article 2(4) encompasses all interference in the domestic affairs of States (D) Article 2(4) encompasses force directed only against a State's territorial integrity\nA: Let's think step by step. We refer to Wikipedia articles on international law for help. Article 2(4) of the UN Charter prohibits states from using armed forces in their international relations. The answer is (A).\n\n", "jurisprudence": "The following are multiple choice questions (with answers) about jurisprudence.\n\nQ: Iverson Jewelers wrote a letter to Miller, 'We have received an exceptionally fine self winding Rolox watch which we will sell to you at a very favorable price.'\n(A) The letter is an offer to sell (B) A valid offer cannot be made by letter. (C) The letter contains a valid offer which will terminate within a reasonable time. (D) The letter lacks one of the essential elements of an offer.\nA: Let's think step by step. We refer to Wikipedia articles on jurisprudence for help. An offer shows the intent to enter into a mutually-beneficial contract with specific terms. An offer can be made by a letter. While this letter indicates the willingness to sell, the lack of specific terms, such as transaction price and offer expiration date, makes it an incomplete offer. The answer is (D).\n\nQ: Functions of the law include all but which of the following?\n(A) maximizing individual freedom (B) providing a basis for compromise (C) keeping the peace (D) promoting the principles of the free enterprise system\nA: Let's think step by step. We refer to Wikipedia articles on jurisprudence for help. Laws are fundamentally about helping resolve disputes between individuals, and therefore essential for maximizing individual freedom, providing a basis for compromise, and keeping the peace. The answer is (D).\n\nQ: The ________ School of jurisprudence postulates that the law is based on what is \"correct.\"\n(A) Natural Law (B) Analytical (C) Historical (D) Sociological\nA: Let's think step by step. We refer to Wikipedia articles on jurisprudence for help. Natural Law School of jurisprudence focuses on the laws of nature, and states that the law should be based on ethics, morals, and what is \"correct\". Analytical deals with the law as it already exists, Historical postulates that the law was found and not made, and Sociological studies how the law and society impact each other. The answer is (A).\n\nQ: Which word best summarizes Weber's explanation of the development of formally rational law?\n(A) Authority. (B) Charisma. (C) Co-operation. (D) Capitalism.\nA: Let's think step by step. We refer to Wikipedia articles on jurisprudence for help. Weber explained the development of formal rationality in laws as how the modern society moved from tradition to rationality, where people decide actions based less on how they were culturally done and more on expected utilities. How rational individuals optimize efficiency of accomplishing tasks for higher rewards is a core principle of Capitalism. The answer is (D).\n\nQ: Which position does Rawls claim is the least likely to be adopted by the POP (people in the original position)?\n(A) The POP would choose equality above liberty. (B) The POP would opt for the 'maximin' strategy. (C) The POP would opt for the 'difference principle'. (D) The POP would reject the 'system of natural liberty.'\nA: Let's think step by step. We refer to Wikipedia articles on jurisprudence for help. The POP would opt for the 'maximin' strategy, opt for the 'difference principle', and reject the 'system of natural liberty', but the POP would not choose equality above liberty, since the POP assume both equal and free citizens. The answer is (A).\n\n", "logical_fallacies": "The following are multiple choice questions (with answers) about logical fallacies.\n\nQ: When an arguer causes confusion during refutation because of real or feigned lack of an ability to engage in refutation, that arguer may have committed the fallacy of\n(A) poor sportsmanship (B) appeal to compassion (C) argument against the person (D) ignorance of refutation\nA: Let's think step by step. We refer to Wikipedia articles on logical fallacies for help. Ignorance of refutation, one of Aristotle's original list of logical fallacies in his Organon, is when someone causes confusion in an argument through real or feigned inability to engage in refutation, in order to win the argument. The answer is (D).\n\nQ: The complex question fallacy consists of\n(A) arguing something is inferior just because it doesn't do something it was never intended to do. (B) including more than one claim in the proposition and treating proof for one claim as proof for all the claims. (C) drawing a conclusion before examining the evidence, and only considering evidence that supports that conclusion. (D) asking a question that includes either an unproven assumption or more than one question, thus making a straightforward yes or no answer meaningless.\nA: Let's think step by step. We refer to Wikipedia articles on logical fallacies for help. The complex question fallacy is when someone makes a single yes or no answer to a question meaningless, by including either an unproven assumption or many questions. The latter is also known as the many questions fallacy. The answer is (D).\n\nQ: Arguing that what is true of the parts must be true of the whole is the fallacy of...\n(A) Division (B) Composition (C) Appeal to the person (D) Appeal to ignorance\nA: Let's think step by step. We refer to Wikipedia articles on logical fallacies for help. Fallacy of composition occurs when someone argues what is true of the parts must be true of the whole. The answer is (B).\n\nQ: Which of the following is true of a valid categorical syllogism?\n(A) The minor premise must deny the antecedent (B) The major premise must affirm the consequent (C) The middle term must be used in at least one premise in a universal or unqualified sense (D) All of the above\nA: Let's think step by step. We refer to Wikipedia articles on logical fallacies for help. A valid categorical syllogism must satisfy several conditions: (1) the syllogism must have exactly three terms (2) every term of the syllogism must be used twice exactly, (3) a term may be used only once in any premise, and (4) the middle term must be used in at least one premise in a universal or unqualified sense, etc. Only (C) is true. The answer is (C).\n\nQ: If someone attacks the character of an opposing arguer, instead of responding to that opponent's arguments, the first person has probably committed which of the following fallacies?\n(A) tu quoque (B) horse laugh (C) argument against the person (D) ignoratio elenchi\nA: Let's think step by step. We refer to Wikipedia articles on logical fallacies for help. The argument against the person fallacy occurs when someone irrelevantly attacks the character of an opposing arguer, instead of addressing that opponent's arguments. The answer is (C).\n\n", "machine_learning": "The following are multiple choice questions (with answers) about machine learning.\n\nQ: Which image data augmentation is most common for natural images?\n(A) random crop and horizontal flip (B) random crop and vertical flip (C) posterization (D) dithering\nA: Let's think step by step. Data augmentation is used to increase the diversity of images in the training dataset. It is important that natural images are kept natural after being augmented. Vertical flips of images are not natural, so (B) is false. Posterization makes the image look like a poster and and dithering increases color depth. None of these two preserve the natural property. The only natural data augmentation technique is (A). The answer is (A).\n\nQ: Traditionally, when we have a real-valued input attribute during decision-tree learning we consider a binary split according to whether the attribute is above or below some threshold. Pat suggests that instead we should just have a multiway split with one branch for each of the distinct values of the attribute. From the list below choose the single biggest problem with Pat\u2019s suggestion:\n(A) It is too computationally expensive. (B) It would probably result in a decision tree that scores badly on the training set and a testset. (C) It would probably result in a decision tree that scores well on the training set but badly on a testset. (D) It would probably result in a decision tree that scores well on a testset but badly on a training set.\nA: Let's think step by step. Because the input is real valued, it is unlikely that the same values appear both at training and test time. This means that while such a decision tree could yield good performance on the training data, when evaluated on the test data it will perform badly because the decision tree won\u2019t know what to do with numbers that did not appear in the training data. The answer is (C).\n\nQ: You are reviewing papers for the World\u2019s Fanciest Machine Learning Conference, and you see submissions with the following claims. Which ones would you consider accepting?\n(A) My method achieves a training error lower than all previous methods! (B) My method achieves a test error lower than all previous methods! (Footnote: When regularisation parameter \u03bb is chosen so as to minimise test error.) (C) My method achieves a test error lower than all previous methods! (Footnote: When regularisation parameter \u03bb is chosen so as to minimise cross-validaton error.) (D) My method achieves a cross-validation error lower than all previous methods! (Footnote: When regularisation parameter \u03bb is chosen so as to minimise cross-validaton error.)\nA: Let's think step by step. In machine learning, we train with some data and fixed hyperparameters and the training error can be arbitrarily low, so (A) can\u2019t be right. Then, one compares different hyperparameters by selecting the model with the lowest cross-validation error, this means that (B) and (D) are not the right procedure. The only relevant number after these is the test error and thus (C) is the right answer. The answer is (C).\n\nQ: A 6-sided die is rolled 15 times and the results are: side 1 comes up 0 times; side 2: 1 time; side 3: 2 times; side 4: 3 times; side 5: 4 times; side 6: 5 times. Based on these results, what is the probability of side 3 coming up when using Add-1 Smoothing?\n(A) 2.0/15 (B) 1.0/7 (C) 3.0/16 (D) 1.0/5\nA: Let's think step by step. Add-1 smoothing adds the value of one to the different counts and then normalizes the probabilities accordingly. The counts after adding one will be: side 1 comes up 1 time; side 2: 2 times; side 3: 3 times; side 4: 4 times; side 5: 5 times; side 6: 6 times. The number of sum one die rolls will be 21, so the probability of drawing a three is 3/21 = 1/7. The answer is (B).\n\nQ: To achieve an 0/1 loss estimate that is less than 1 percent of the true 0/1 loss (with probability 95%), according to Hoeffding's inequality the IID test set must have how many examples?\n(A) around 10 examples (B) around 100 examples (C) between 100 and 500 examples (D) more than 1000 examples\nA: Let's think step by step. By the Hoeffding\u2019s inequality, we expect that with 95% probability the in-sample and out-of-sample errors differ by epsilon when we have N samples if 2 exp(-2 epsilon^2 N)<0.05, this implies that N > -1/(2*epsilon**2) log ( 0.05/2 )= log (40)*5000. Since log(40)>1, we have that one needs more than 1000 examples. The answer is (D).\n\n", "management": "The following are multiple choice questions (with answers) about management.\n\nQ: How can organisational structures that are characterised by democratic and inclusive styles of management be described?\n(A) Hierarchical (B) Bureaucratic (C) Flat (D) Functional\nA: Let's think step by step. We refer to Wikipedia articles on management for help. Flat organizational structures are characterized by democratic and inclusive styles of management, and have few (if any) levels of management between the workers and managers. The answer is (C).\n\nQ: Hygiene factors are associated with which writer?\n(A) Frederick Hertzberg (B) D.C. McClelland (C) Abraham Maslow (D) Douglas McGregor\nA: Let's think step by step. We refer to Wikipedia articles on management for help. Hygiene factors include compensation, company policies, supervision, interpersonal relations, and work environments. Hertzberg lists them as factors that cannot motivate employees but can minimize job dissatisfaction. The answer is (A).\n\nQ: What characteristic is not a key feature of the 'open systems' model of management?\n(A) Morale (B) Innovation (C) Growth resource (D) Adaptation\nA: Let's think step by step. We refer to Wikipedia articles on management for help. The key characteristics of an open system in management include innovation, growth resource, and adaption, but do not include morale. The answer is (A).\n\nQ: Which element of the cultural web forms regalia?\n(A) Symbols (B) Rituals and routines (C) Power structures (D) Control systems\nA: Let's think step by step. We refer to Wikipedia articles on management for help. The cultural web is a tool for mapping an organization's culture, where symbols form the regalia that visually expresses the values that the organization holds as important. The answer is (A).\n\nQ: What are the two main dimensions of the Ohio Studies into leadership?\n(A) Starting position and end position (B) Initial environment and changed environment (C) Organisational structure and conditioning (D) Initiating structure and considerations\nA: Let's think step by step. We refer to Wikipedia articles on management for help. The Ohio State Leadership Studies conducted in the 1940s identified initiating structure and consideration as the two main dimensions of leader behavior. The answer is (D).\n\n", "marketing": "The following are multiple choice questions (with answers) about marketing.\n\nQ: Although the content and quality can be as controlled as direct mail, response rates of this medium are lower because of the lack of a personal address mechanism. This media format is known as:\n(A) Care lines. (B) Direct mail. (C) Inserts. (D) Door to door.\nA: Let's think step by step. We refer to Wikipedia articles on marketing for help. Door to door marketing delivers non-addressed items within all buildings within a geographic area. While it can control the content and quality as well as direct mail marketing, its response rate is lower because of the lack of a personal address mechanism. The answer is (D).\n\nQ: In an organization, the group of people tasked with buying decisions is referred to as the _______________.\n(A) Outsourcing unit. (B) Procurement centre. (C) Chief executive unit. (D) Decision-making unit.\nA: Let's think step by step. We refer to Wikipedia articles on marketing for help. In an organization, the group of the people tasked with buying decision is referred to as the decision-making unit. The answer is (D).\n\nQ: The single group within society that is most vulnerable to reference group influence is:\n(A) The older consumer who feels somewhat left out of things. (B) The married women, many of whom feel a need for stability in their lives. (C) New immigrants who really want to assimilate into their new culture. (D) Children, who base most of their buying decisions on outside influences.\nA: Let's think step by step. We refer to Wikipedia articles on marketing for help. Children, who mostly based their buying decisions on outside influences, are the single group within society that is more vulnerable to reference group influence. The answer is (D).\n\nQ: Which of the following is an assumption in Maslow's hierarchy of needs?\n(A) Needs are dependent on culture and also on social class. (B) Lower-level needs must be at least partially satisfied before higher needs can affect behaviour. (C) Needs are not prioritized or arranged in any particular order. (D) Satisfied needs are motivators, and new needs emerge when current needs remain unmet.\nA: Let's think step by step. We refer to Wikipedia articles on marketing for help. Maslow's hierarchy of needs, from the bottom upwards, are physiological (food and clothing), safety, love and belonging needs, esteem, and self-actualization. Lower-level needs must be at least partially satisfied before higher ones can affect behavior. The answer is (B).\n\nQ: _____________ is a natural outcome when combining demographic and geographic variables.\n(A) Geodemographics (B) Product differentiation. (C) ANSOFF matrix. (D) Brand management.\nA: Let's think step by step. We refer to Wikipedia articles on marketing for help. Geodemographics is a natural outcome when combining demographic and geographic variables. The answer is (A).\n\n", "medical_genetics": "The following are multiple choice questions (with answers) about medical genetics.\n\nQ: The stage of meiosis in which chromosomes pair and cross over is:\n(A) prophase I (B) metaphase I (C) prophase II (D) metaphase II\nA: Let's think step by step. We refer to Wikipedia articles on medical genetics for help. Prophase I is the stage of meiosis where homologous chromosomes pair with each other and exchange genetic material. The answer is (A).\n\nQ: DNA ligase is\n(A) an enzyme that joins fragments in normal DNA replication (B) an enzyme of bacterial origin which cuts DNA at defined base sequences (C) an enzyme that facilitates transcription of specific genes (D) an enzyme which limits the level to which a particular nutrient reaches\nA: Let's think step by step. We refer to Wikipedia articles on medical genetics for help. DNA ligase is a type of enzyme (EC 6.5.1.1) responsible for joining DNA strands together by catalyzing a phosphodiester bond. The answer is (A).\n\nQ: Which of the following conditions does not show multifactorial inheritance?\n(A) Pyloric stenosis (B) Schizophrenia (C) Spina bifida (neural tube defects) (D) Marfan syndrome\nA: Let's think step by step. We refer to Wikipedia articles on medical genetics for help. Multifactorial inheritance is when more than a single factor is responsible for causing a given trait or health problem. Genes cannot be the only factor. Marfan syndrome, on the other hand, requires only one abnormal copy of the of the Marfan gene, from one parent, to inherit the trait. The answer is (D).\n\nQ: A gene showing codominance\n(A) has both alleles independently expressed in the heterozygote (B) has one allele dominant to the other (C) has alleles tightly linked on the same chromosome (D) has alleles expressed at the same time in development\nA: Let's think step by step. We refer to Wikipedia articles on medical genetics for help. Codominance, as it relates to genetics, refers to a type of genetic inheritance where the phenotype of both the parents is easily observed in the offspring. A heterozygote is an individual having two different alleles of a gene. The answer is (A).\n\nQ: Large triplet repeat expansions can be detected by:\n(A) polymerase chain reaction. (B) single strand conformational polymorphism analysis. (C) Southern blotting. (D) Western blotting.\nA: Let's think step by step. We refer to Wikipedia articles on medical genetics for help. A Southern blot is a method in molecular biology for detecting specific DNA sequences in a sample. Large triplet repeat expansions are usually detected with this method. The answer is (C).\n\n", "miscellaneous": "The following are multiple choice questions (with answers) about miscellaneous.\n\nQ: Which of these songs was a Top 10 hit for the rock band The Police?\n(A) 'Radio Ga-Ga' (B) 'Ob-la-di Ob-la-da' (C) 'De Do Do Do De Da Da Da' (D) 'In-a-Gadda-Da-Vida'\nA: Let's think step by step. We refer to Wikipedia for help. Radio Ga-Ga is by Queen. Ob-la-di Ob-la-da is by The Beatles. And In-a-Gadda-Da-Vida is by Iron Butterfly. Leaving 'De Do Do Do De Da Da Da' as the only song by The Police, and also a Top 10 hit. The answer is (C).\n\nQ: What place is named in the title of the 1979 live album by rock legends Cheap Trick?\n(A) Budapest (B) Budokan (C) Bhutan (D) Britain\nA: Let's think step by step. We refer to Wikipedia for help. Nippon Budokan is an indoor arena in Tokyo, Japan renowned for hosting rock music concerts including Cheap Trick in 1978. 'Cheap Trick at Budokan' became the name of their album. The answer is (B).\n\nQ: What is produced during photosynthesis?\n(A) hydrogen (B) nylon (C) oxygen (D) light\nA: Let's think step by step. We refer to Wikipedia for help. Photosynthesis is the process in which green plants use the green pigment chlorophyll to synthesize foods with water and carbon dioxide. Oxygen is the byproduct of this process. The answer is (C).\n\nQ: Who is the shortest man to ever win an NBA slam dunk competition?\n(A) Anthony 'Spud' Webb (B) Michael 'Air' Jordan (C) Tyrone 'Muggsy' Bogues (D) Julius 'Dr J' Erving\nA: Let's think step by step. We refer to Wikipedia for help. In 1986, Spud Webb, standing only 5'7\" became the shortest NBA player in history to win an official slam dunk contest. The answer is (A).\n\nQ: How many axles does a standard automobile have?\n(A) one (B) two (C) four (D) eight\nA: Let's think step by step. We refer to Wikipedia for help. Most cars have two axles to rotate the wheels.. The answer is (B).\n\n", "moral_disputes": "The following are multiple choice questions (with answers) about moral disputes.\n\nQ: Baron admits that the versions of the ticking bomb hypothetical she discusses are \"stunningly stupid,\" but she claims this is actually evidence of\n(A) the stupidity of most traditional philosophical examples. (B) a general lack of intelligence among people with advanced degrees. (C) the wrongness of torture. (D) the readiness on the part of many intelligent people to see torture as the best solution to deal with terrorism.\nA: Let's think step by step. We refer to Wikipedia articles on moral disputes for help. The ticking bomb hypothetical poses a problem where many people will die to an exploding bomb, if the hypothetical terrorist does not disclose how to defuse it. Baron sees this hypothetical as silly, but its prevalence does suggest intelligent people, particularly utilitarians, see torture as justifiable to save the lives in this scenario. The answer is (D).\n\nQ: A fertilized ovum is also known as\n(A) a zygote. (B) an embryo. (C) a viability. (D) a blastocyst.\nA: Let's think step by step. We refer to Wikipedia articles on moral disputes for help. Once a single sperm penetrates the layers of an egg to form a new cell, that cell is called a zygote. The answer is (A).\n\nQ: Pence compares six different cases of reproduction, from natural twinning to SCNT. What conclusion does he draw from this comparison?\n(A) SCNT is not a different kind of reproduction because there are no morally relevant differences between it and other permissible means of reproduction. (B) Because there is a low risk of harm for natural twinning, there will be a low risk of harm for SCNT. (C) both A and B (D) neither A nor B\nA: Let's think step by step. We refer to Wikipedia articles on moral disputes for help. SCNT, known as 'somatic cell nuclear transfer', is a technique in which the nucleus of a somatic (body) cell is transferred to the cytoplasm of an enucleated egg. This is used in farm animal cloning. Pence may compare this to other processes of reproduction, and see little moral separation. The answer is (A).\n\nQ: Which of the following is an intrinsic good, according to Singer?\n(A) being in the upper class (B) the ability to learn and solve problems (C) being more beautiful than most of the population (D) being shorter than average\nA: Let's think step by step. We refer to Wikipedia articles on moral disputes for help. Bioethicist Peter Singer sees intrinsic value as innate values conferred by oneself, for oneself. Innanimate objects can be beautiful, short, or have some valuable criteria, but capabilities are intrinsically good. The answer is (B).\n\nQ: According to Metz, what is wrong with consequentialist arguments against capital punishment based on African values?\n(A) It is unclear as of yet whether or not capital punishment deters harm to the community. (B) It is unclear as of yet whether or not capital punishment deters harm to any individuals. (C) Consequentialism is not supported by African values. (D) Even though consequentialism is supported by African values, no consequentialist arguments framed in terms of African values have been offered.\nA: Let's think step by step. We refer to Wikipedia articles on moral disputes for help. Thaddeus Metz is a humanities research professor focusing on ethical philosophy in South Africa. Metz has written the death penalty is unjustified as it treats individuals as incapable of communal relations. It is unclear that capital punishment is to the benefit of, or a deterrent of harm to the community. The answer is (A).\n\n", "moral_scenarios": "The following are multiple choice questions (with answers) about moral scenarios.\n\nQ: For which of these two scenarios does the main character (who uses I/me/my) do something clearly morally wrong, according to ordinary moral standards in the US as of 2020? Scenario 1 | I took his wallet after he harassed me at the bar. Scenario 2 | I made love to my newborn baby boy.\n(A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not wrong\nA: Let's think step by step. We refer to Wikipedia articles on moral scenarios for help. Stealing from someone is not considered an ethical response to harrassment. And having sex with a minor is egregiously unethical. The answer is (A).\n\nQ: For which of these two scenarios does the main character (who uses I/me/my) do something clearly morally wrong, according to ordinary moral standards in the US as of 2020? Scenario 1 | I asked for a cigarette from the guy outside the store. Scenario 2 | When the security guard left his post, I pocketed a rare coin.\n(A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not wrong\nA: Let's think step by step. We refer to Wikipedia articles on moral scenarios for help. Asking for something is not morally wrong. Stealing is morally wrong. The answer is (C).\n\nQ: For which of these two scenarios does the main character (who uses I/me/my) do something clearly morally wrong, according to ordinary moral standards in the US as of 2020? Scenario 1 | During the house's demolition, I pushed him into a wall. Scenario 2 | When I choose a babysitter for my twin nieces, I don't care about her racial background.\n(A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not wrong\nA: Let's think step by step. We refer to Wikipedia articles on moral scenarios for help. Physical violence is morally wrong. Being open to hiring people from different racial backgrounds is an expected moral stance. The answer is (B).\n\nQ: For which of these two scenarios does the main character (who uses I/me/my) do something clearly morally wrong, according to ordinary moral standards in the US as of 2020? Scenario 1 | I bought an old pick-up truck for work. Scenario 2 | I looked up at the gorgeous waterfall and threw a rock at it.\n(A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not wrong\nA: Let's think step by step. We refer to Wikipedia articles on moral scenarios for help. Neither of these are wrong as no one is wronged by buing a pick-up truck or throwing a rock into a stream. The answer is (D).\n\nQ: For which of these two scenarios does the main character (who uses I/me/my) do something clearly morally wrong, according to ordinary moral standards in the US as of 2020? Scenario 1 | I really loved my newborn baby boy. Scenario 2 | The new girl is embarrassed to live in a trailer park, so I invite the class to her house.\n(A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not wrong\nA: Let's think step by step. We refer to Wikipedia articles on moral scenarios for help. Loving someone is not wrong. However, exposing something that someone is embarrassed about could be considered quite mean. The answer is (C).\n\n", "nutrition": "The following are multiple choice questions (with answers) about nutrition.\n\nQ: What is the first-line drug for patients with type 2 diabetes and obesity, as of 2020?\n(A) Acarbose (B) Metformin (C) Sulphonylureas (D) Insulin\nA: Let's think step by step. We refer to Wikipedia articles on nutrition for help. Metformin (Fortamet, Glumetza, or others) is usually the first medication prescribed for type 2 diabetes, as well as obesity. It works by lowering glucose production in the liver and improving the body's sensitivity to insulin. The answer is (B).\n\nQ: Which of the following statements is correct (according to knowledge in 2020)?\n(A) Consumers with phenylketonuria must avoid the consumption of the sweetener aspartame (B) Consumers with phenylketonuria must avoid the consumption of the sweetener saccharin (C) Consumers with phenylketonuria must avoid the consumption of the sweetener sucralose (D) Consumers with phenylketonuria must avoid the consumption of the sweetener acesulfame K\nA: Let's think step by step. We refer to Wikipedia articles on nutrition for help. People with phenylketonuria (PKU) cannot break down the amino acid phenylalanine. As it builds up in the blood and brain it can lead to brain damage. People with PKU should avoid foods that are converted to phenylalanine in the body, such as aspartame. The answer is (A).\n\nQ: Which of the following statements about iodine is correct, as of 2020?\n(A) 50% of adults consume iodine at levels below the RNI (B) Dairy products are a poor source of iodine (C) The iodine content of organic milk is generally lower that the level in non-organic milk (D) UK dietary reference values recommend an increase in iodine intake in pregnancy\nA: Let's think step by step. We refer to Wikipedia articles on nutrition for help. Organic milk usually has less iodine content than non-organic milk. The answer is (C).\n\nQ: Which of the following is the most plausible explanation for the protective effect of dietary fibre against cancer of the colon, as of 2020?\n(A) Propionic acid, formed during colonic fibre fermentation inhibits liver fatty acid synthesis (B) Butyric acid, formed during colonic fibre fermentation stimulates \"silencing\" of the SLC5A8 tumour suppressor gene (C) None of these options are correct (D) Butyric acid, formed during colonic fibre fermentation stimulates anti-oxidant defences in the colon\nA: Let's think step by step. We refer to Wikipedia articles on nutrition for help. Dietary fibre is inversely proportional to the risk of colorectal cancer. This is presumed because butyric acid (BA) stimulates antioxidants which help protect the colon from cancerous tumors. The answer is (D).\n\nQ: In a cohort study, the risk ratio of developing diabetes was 0.86 when comparing consumers of tea (the exposed) to those who did not drink tea (the unexposed). Which one statement is correct (according to knowledge in 2020)?\n(A) The tea drinkers have lower risk of developing diabetes. (B) The tea drinkers have higher risk of developing diabetes. (C) Based on the information given we cannot tell if the observed difference in disease risk is the result of chance. (D) The risk ratio is close to the value one, so there is no difference in disease risk between the two groups.\nA: Let's think step by step. We refer to Wikipedia articles on nutrition for help. The risk ratio is not sufficiently reduced that it could not be explained by random chance given the studies sample size. The answer is (C).\n\n", "philosophy": "The following are multiple choice questions (with answers) about philosophy.\n\nQ: The study of reality in the broadest sense, an inquiry into the elemental nature of the universe and the things in it, is known as _____.\n(A) metaphysics (B) epistemology (C) quantum physics (D) axiology\nA: Let's think step by step. We refer to Wikipedia articles on philosophy for help. Among the options, only metaphysics studies the nature of reality and existence. The answer is (A).\n\nQ: According to Moore\u2019s \u201cideal utilitarianism,\u201d the right action is the one that brings about the greatest amount of:\n(A) pleasure. (B) happiness. (C) good. (D) virtue.\nA: Let's think step by step. We refer to Wikipedia articles on philosophy for help. Moore's \"ideal utilitarianism\" states that one's actions should maximize intrinsic goods. The answer is (C).\n\nQ: Before Tolstoy's Christian conversion, what was his perspective on the meaning of life?\n(A) optimist (B) satisfied (C) nominally religious (D) pessimist\nA: Let's think step by step. We refer to Wikipedia articles on philosophy for help. Before his conversion, Tolstoy feels that life was uncertain, which is a pessimist's point of view. The answer is (D).\n\nQ: According to d'Holbach, people always act according to _____.\n(A) free choices (B) dictates of the soul (C) necessary natural laws (D) undetermined will\nA: Let's think step by step. We refer to Wikipedia articles on philosophy for help. d'Holbach believes that people act according to necessary laws, and it proves nothing about people's free will. The answer is (C).\n\nQ: Psychological egoism is:\n(A) an ethical theory about how we ought to behave. (B) a generalization concerning the way people tend to behave. (C) a claim about human nature and the ways people are capable of behaving. (D) none of the above.\nA: Let's think step by step. We refer to Wikipedia articles on philosophy for help. Psychological egoism suggests that one behaves based on what makes one feels good, hence it is a claim about human nature and how humans are capable of behaving. The answer is (C).\n\n", "prehistory": "The following are multiple choice questions (with answers) about prehistory.\n\nQ: What is the approximate mean cranial capacity of Homo erectus?\n(A) under 650 cc (B) about 800 cc (C) just under 1000 cc (D) 1200 cc\nA: Let's think step by step. We refer to Wikipedia articles on prehistory for help. The average cranium capacity of Homo erectus is less than 1000 cubic cm. The answer is (C).\n\nQ: According to Timothy Pauketat, the evidence for social stratification and political power at Cahokia suggests:\n(A) a center of Mississippian civilization with conditions similar to the rise of early states. (B) the limitations of authority in a Native American society of egalitarian foragers. (C) a simple chiefdom or perhaps a complex chiefdom had evolved by A.D. 1500. (D) a center of Mississippian civilization with conditions similar to societies on the Northwest Coast of North America.\nA: Let's think step by step. We refer to Wikipedia articles on prehistory for help. Timothy Pauketat is known for his research on Cahokia, the center of the Mississippian culture, where he found similar conditions to the rise of early states. The answer is (A).\n\nQ: Recent research on hominid species dating from the Middle Pliocene indicates there was (as of 2020):\n(A) a great amount of species diversity, or a single species that exhibited a lot of diversity. (B) very little species diversity during this period and very few hominids. (C) decreased species diversity due to a prolonged ice age followed by a severe drought. (D) decreased species diversity but increased numbers of hammerstones and flakes, indicating stone tool manufacture.\nA: Let's think step by step. We refer to Wikipedia articles on prehistory for help. Recent research has recognized multiple hominid species from the Middle Pliocene, meaning that there is a great amount of species diversity or diversity in a single species. The answer is (A).\n\nQ: Researchers now believe that the decline of the Maya was caused chiefly by:\n(A) a cataclysm of some kind, such as an earthquake, volcano, or tsunami. (B) ecological degradation resulting from slash-and-burn farming techniques. (C) endless wars between neighboring Mayan city-states. (D) practices of interbreeding that led to a steep rise in congenital disorders.\nA: Let's think step by step. We refer to Wikipedia articles on prehistory for help. Researchers believe that the Maya collapse was mainly caused by over-exploitation of natural resources like the slash-and-burn farming techniques. The answer is (B).\n\nQ: The great Mayan king Pacal built temples in the city of Palenque in order to:\n(A) satisfy the powerful Mayan astronomer priests. (B) display his generosity to the common people, since they were allowed to live in the temples. (C) frighten away enemies, in particular the Spaniards. (D) legitimize his kingship, since his father was not royal.\nA: Let's think step by step. We refer to Wikipedia articles on prehistory for help. Pacal built the temples as the funerary monument to legitimize his kingship. The answer is (D).\n\n", "professional_accounting": "The following are multiple choice questions (with answers) about professional accounting.\n\nQ: An auditor traces the serial numbers on equipment to a nonissuer\u2019s subledger. Which of the following management assertions is supported by this test?\n(A) Valuation and allocation (B) Completeness (C) Rights and obligations (D) Presentation and disclosure\nA: Let's think step by step. We refer to Wikipedia articles on accounting for help. The completeness assertion is tested by tracing supporting documents to the record entries. The answer is (B).\n\nQ: One hundred years ago, your great-great-grandmother invested $100 at 5% yearly interest. What is the investment worth today?\n(A) $13,000 (B) $600 (C) $15,000 (D) $28,000\nA: Let's think step by step. We refer to Wikipedia articles on accounting for help. A $100 investment at 5% yearly interest is worth 100*(1.05)^100=13150 after 100 years, which is around $13,000. The answer is (A).\n\nQ: On January 1, year 1, Alpha Co. signed an annual maintenance agreement with a software provider for $15,000 and the maintenance period begins on March 1, year 1. Alpha also incurred $5,000 of costs on January 1, year 1, related to software modification requests that will increase the functionality of the software. Alpha depreciates and amortizes its computer and software assets over five years using the straight-line method. What amount is the total expense that Alpha should recognize related to the maintenance agreement and the software modifications for the year ended December 31, year 1?\n(A) $5,000 (B) $13,500 (C) $16,000 (D) $20,000\nA: Let's think step by step. We refer to Wikipedia articles on accounting for help. The maintenance period begins on March 1, so only 10 months of expenses should be recognized, which is $15,000/12*10=$12,500. The software modification cost is amortized over 5 years, so each year is $5,000/5=$1,000. So the total expense is $12,500+$1,000=$13,500. The answer is (B).\n\nQ: Krete is an unmarried taxpayer with income exclusively from wages. By December 31, year 1, Krete's employer has withheld $16,000 in federal income taxes and Krete has made no estimated tax payments. On April 15, year 2, Krete timely filed for an extension request to file her individual tax return, and paid $300 of additional taxes. Krete's year 1 tax liability was $16,500 when she timely filed her return on April 30, year 2, and paid the remaining tax liability balance. What amount would be subject to the penalty for underpayment of estimated taxes?\n(A) $0 (B) $500 (C) $1,650 (D) $16,500\nA: Let's think step by step. We refer to Wikipedia articles on accounting for help. The tax due after withholding is $16,500-$16,000=$500, which is less than $1000, hence there is no underpayment penalty of estimated taxes. The answer is (A).\n\nQ: Box a nongovernmental not-for-profit organization had the following transactions during the year: Proceeds from sale of investments $80000 Purchase of property plant and equipment $10000 Proceeds from long-term debt $100000 Loss on sale of investment $5000 What amount should be reported as net cash provided by financing activities in Box's statement of cash flows?\n(A) $70,000 (B) $75,000 (C) $80,000 (D) 100000\nA: Let's think step by step. We refer to Wikipedia articles on accounting for help. Among the four transactions, only Proceeds from long-term debt belongs to the financing activities section of cashflow, hence the amount reported should be $100000. The answer is (D).\n\n", "professional_law": "The following are multiple choice questions (with answers) about professional law.\n\nQ: A son owed a creditor $5,000. The son's father contacted the creditor and told him that he wanted to pay the son's debt. The father signed a document that stated the father would pay the son's debt at a rate of $500 a month for 10 months. The creditor made no written or oral commitment to forbear to sue the son to collect the $5,000 debt, and the father made no oral or written request for any such forbearance. For the next five months, the father made and the creditor accepted the $500 monthly payments as agreed. During that period, the creditor, in fact, did forbear to take any legal action against the son. However, the father then informed the creditor that he would make no further payments on the debt. Which of the following is the most persuasive argument that the father is liable to the creditor under the terms of their agreement?\n(A) The father's promise and the creditor's reliance thereon, if proved, gave rise to a valid claim by the creditor against the father based on the doctrine of promissory estoppel. (B) Because it was foreseeable that the father's promise would induce the creditor to forbear taking any action against the son, such forbearance was, as a matter of law, a bargained-for consideration for the father's promise. (C) The father's five payments to the creditor totaling $2,500 manifested a serious intent on the father's part to be contractually bound, and such manifestation is generally recognized as an effective substitute for consideration. (D) By assuming the antecedent debt obligation that the son owed to the creditor, the father became a surety whose promise to the creditor was enforceable, since it was in writing and supported by adequate consideration. \nA: Let's think step by step. We refer to Wikipedia articles on law for help. The doctrine of promissory estoppel stops a person from going back on a promise in contract law, hence option (A) should be the most persuasive argument. The answer is (A).\n\nQ: A state has recently enacted a statute prohibiting the disposal of any nuclear wastes within the state. This law does not contravene or conflict with any federal statutes. A man operates a company in the state that is engaged in the disposal of nuclear wastes. Subsequent to the passage of the state statute, the man, not yet aware of the new law, entered into contracts with many out-of-state firms to dispose of their nuclear wastes in the state. On account of this new law, however, the man will be unable to perform these contracts. Assume that the man has standing to challenge this state law. Which of the following presents his strongest constitutional grounds to challenge the state law prohibiting the disposal of nuclear wastes within the state?\n(A) The commerce clause. (B) The equal protection clause of the Fourteenth Amendment. (C) The privileges and immunities clause of Article IV, Section 2. (D) The contract clause.\nA: Let's think step by step. We refer to Wikipedia articles on law for help. The commerce clause states that Congress shall have the power to regulate commerce with foreign Nations, and among the several States, and with the Indian Tribes. The statute affects inter-state commerce which puts it into question. Hence the man's strongest argument should be the commerce clause. The answer is (A).\n\nQ: On October 1, 1980, a developer, owner of several hundred acres in a rural county, drafted a general development plan for the area. The duly recorded plan imposed elaborate limitations and restrictions upon the land in the plan, which was to be developed as a residential district. The restrictions were to extend to all persons acquiring any of the lots and to their heirs, assigns, and lessees. It was further provided that all subsequent owners would be charged with due notice of the restrictions. Among those restrictions in the general plan were the following:(22) A franchise right is created in a strip of land 10 feet in width along the rear of each lot for the use of public utility companies with right of ingress and egress. (23) No house or structure of any kind shall be built on the aforementioned strip of land running through the said blocks. In 2000, a retiree purchased one of the lots, built a house, and erected a fence in the rear of his property within the restricted area. In 2004, a teacher purchased a lot adjacent to the retiree's property and built a new house. Two years later, a librarian purchased the lot that adjoined the teacher's property. The three deeds to those properties each contained references to the deed book where the general plan was recorded. In 2008, the librarian began the construction of a seven-foot post-and-rail fence along the line dividing his lot with the teacher's, and along the center of the area subject to the franchise right. Although the teacher objected to its construction, the fence was completed. If the teacher seeks a mandatory injunction to compel removal of the librarian's fence, the court will most likely\n(A) grant relief, because the fence was in violation of the easement restriction. (B) grant relief, because the encroachment of the fence violated the restriction in the original plan. (C) deny relief, because the teacher failed to enforce the restriction against the retiree. (D) deny relief, because the fence would not be construed as \"a structure\" within the terms of the restriction. \nA: Let's think step by step. We refer to Wikipedia articles on law for help. The restrictions in the original plan say no house or structure of any kind shall be built on the aforementioned strip of land running through the said blocks. Hence the court will most likely grant relief because the fence violated the restriction in the original plan. The answer is (B).\n\nQ: Judge took judicial notice of some facts at the beginning of the trial. Which of the following is not an appropriate kind of fact for judicial notice?\n(A) Indisputable facts. (B) Facts that have been asserted by individual political organizations. (C) Facts recognized to be true by common knowledge. (D) Facts capable of scientific verification.\nA: Let's think step by step. We refer to Wikipedia articles on law for help. Among the options, facts that have been asserted by individual political organizations is not an appropriate kind of fact for judicial notice. The answer is (B).\n\nQ: A state legislature has recently enacted a statute making it a misdemeanor to curse or revile or use obscene or opprobrious language toward or in reference to a police officer perfonning his duties. A student at a state university organized a demonstration on campus to protest the war. The rally was attended by a group of 50 students who shouted anti-war messages at cars passing by. To show his contempt for the United States, the student sewed the American flag to the rear of his jeans. When a police officer saw the flag sown on the student's jeans, he approached and told him to remove the flag or he would be placed under arrest. The student became angered and shouted at the police officer, \"Listen, you bastard, I'll wear this rag anywhere I please. \" The student was subsequently placed under arrest and charged with violating the state statute. The student subsequently brings suit in state court challenging the constitutionality of the statute. The strongest constitutional argument for the student is that\n(A) the statute is void for vagueness under the Fourteenth Amendment's due process clause. (B) the statute is invalid because it violates the petitioner's freedom of speech under the First Amendment. (C) the statute is an abridgment of freedom of speech under the First Amendment because less restrictive means are available for achieving the same purpose. (D) the statute is overbroad and consequently invalid under the First and Fourteenth Amendments.\nA: Let's think step by step. We refer to Wikipedia articles on law for help. The Fourteenth Amendment further supports the First Amendment by establishing a due process clause. Hence the strongest argument should be the statute is overbroad and consequently invalid under the First and Fourteenth Amendments. The answer is (D).\n\n", "professional_medicine": "The following are multiple choice questions (with answers) about professional medicine.\n\nQ: A 22-year-old male marathon runner presents to the office with the complaint of right-sided rib pain when he runs long distances. Physical examination reveals normal heart and lung findings and an exhalation dysfunction at ribs\u00a04-5 on the right. Which of the following muscles or muscle groups will be most useful in correcting this dysfunction utilizing a direct method?\n(A) anterior scalene (B) latissimus dorsi (C) pectoralis minor (D) quadratus lumborum\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. Among the options, only pectoralis minor muscle origins from the outer surfaces of the 3rd to 5th ribs. The answer is (C).\n\nQ: A 36-year-old male presents to the office with a\u00a03-week\u00a0history of low back pain. He denies any recent trauma but says that he climbs in and out of his truck numerous times a day for his job. Examination of the patient in the prone position reveals a deep sacral sulcus on the left, a posterior inferior lateral angle on the right, and a lumbosacral junction that springs freely on compression. The most likely diagnosis is\n(A) left-on-left sacral torsion (B) left-on-right sacral torsion (C) right unilateral sacral flexion (D) right-on-right sacral torsion\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. The deep sulcus on the left, a posterior ILA on the right, with a negative spring test suggests a right-on-right sacral torsion. All other options have a deep sulcus on the right. The answer is (D).\n\nQ: A 44-year-old man comes to the office because of a 3-day history of sore throat, nonproductive cough, runny nose, and frontal headache. He says the headache is worse in the morning and ibuprofen does provide some relief. He has not had shortness of breath. Medical history is unremarkable. He takes no medications other than the ibuprofen for pain. Vital signs are temperature 37.4\u00b0C (99.4\u00b0F), pulse 88/min, respirations 18/min, and blood pressure 120/84 mm Hg. Examination of the nares shows erythematous mucous membranes. Examination of the throat shows erythema and follicular lymphoid hyperplasia on the posterior oropharynx. There is no palpable cervical adenopathy. Lungs are clear to auscultation. Which of the following is the most likely cause of this patient's symptoms?\n(A) Allergic rhinitis (B) Epstein-Barr virus (C) Mycoplasma pneumonia (D) Rhinovirus\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. The symptoms, especially the headache, suggest that the most likely cause is Rhinovirus. Epstein-Barr virus will cause swollen lymph nodes but there is no palpable cervical adenopathy. Lungs are clear to auscultation suggests it's not Mycoplasma pneumonia. The answer is (D).\n\nQ: A previously healthy 32-year-old woman comes to the physician 8 months after her husband was killed in a car crash. Since that time, she has had a decreased appetite and difficulty falling asleep. She states that she is often sad and cries frequently. She has been rechecking the door lock five times before leaving her house and has to count exactly five pieces of toilet paper before she uses it. She says that she has always been a perfectionist but these urges and rituals are new. Pharmacotherapy should be targeted to which of the following neurotransmitters?\n(A) Dopamine (B) Glutamate (C) Norepinephrine (D) Serotonin\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. The patient feels sad and among the options, only Dopamine and Serotonin can help increase positive emotions. Serotonin also affects digestion and metabolism, which can help the patient's decreased appetite and sleep difficulty. The answer is (D).\n\nQ: A 42-year-old man comes to the office for preoperative evaluation prior to undergoing adrenalectomy scheduled in 2 weeks. One month ago, he received care in the emergency department for pain over his right flank following a motor vehicle collision. At that time, blood pressure was 160/100 mm Hg and CT scan of the abdomen showed an incidental 10-cm left adrenal mass. Results of laboratory studies, including complete blood count, serum electrolyte concentrations, and liver function tests, were within the reference ranges. The patient otherwise had been healthy and had never been told that he had elevated blood pressure. He takes no medications. A follow-up visit in the office 2 weeks ago disclosed elevated urinary normetanephrine and metanephrine and plasma aldosterone concentrations. The patient was referred to a surgeon, who recommended the adrenalectomy. Today, vital signs are temperature 36.6\u00b0C (97.9\u00b0F), pulse 100/min, respirations 14/min, and blood pressure 170/95 mm Hg. Physical examination discloses no significant findings. Initial preoperative preparation should include treatment with which of the following?\n(A) Labetalol (B) A loading dose of potassium chloride (C) Nifedipine (D) Phenoxybenzamine\nA: Let's think step by step. We refer to Wikipedia articles on medicine for help. The symptoms and the adrenal mass suggested pheochromocytoma, and the blood pressure indicates hypertension. Phenoxybenzamine is used to treat hypertension caused by pheochromocytoma. The answer is (D).\n\n", "professional_psychology": "The following are multiple choice questions (with answers) about professional psychology.\n\nQ: In the construction of a multiple regression equation for purposes of prediction, the optimal combination of measures is one in which the predictors\n(A) are uncorrelated with each other but are moderately correlated with the criterion (B) have low correlations with each other and low correlations with the criterion (C) are highly intercorrelated with each other and moderately correlated with the criterion (D) have low correlations with the criterion bur are moderately correlated with each other\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. The basis of multiple regression is to assess the relationship between one continuous variable and a set of independent variables. So the predictors should be uncorrelated with each other but are moderately correlated with the criterion. The answer is (A).\n\nQ: There are three ways to measure the Central Tendency: the Mean, the Median and the Mode. From your knowledge about them, what is the mode?\n(A) less sensitive to extreme scores than the mean (B) more useful for skewed distributions (C) sensitive to extreme values and highly skewed distributions (D) the most frequently occurring number\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. The definition of mode is the most frequently occurring number. The answer is (D).\n\nQ: Carl Jung believed that a client's transference:\n(A) is a fantasy that distracts the client from reality. (B) represents \u201cmixed feelings\u201d toward the therapist. (C) \"is a form of \"\"acting out.\"\"\" (D) reflects the client\u2019s personal and collective unconscious.\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. Transference is a phenomenon that a person's feelings are unconsciously redirected, so it reflects the client's personal and collective unconscious. The answer is (D).\n\nQ: In terms of Hofstede\u2019s (1980) five cultural dimensions, the United States scores at the top of the scale on:\n(A) individualism. (B) individualism and power distance. (C) power distance and masculinity. (D) uncertainty avoidance.\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. US scores highest on individualism among the five cultural dimensions. The answer is (A).\n\nQ: One of your therapy clients asks your advice about a good weight- reduction program. You have investigated the programs in the community and are enrolled in the one you consider the best. This program offers a $50 bonus to its patrons for each new person they bring into the program. Under these circumstances, your most appropriate response would be to\n(A) tell your client the pros and cons of each program you know about except for the one in which you are enrolled (B) recommend to your client the program in which you are enrolled and explain the $50 bonus you will receive (C) recommend to your client the program in which you are enrolled and offer to have the $50 bonus credited to your client's account in the program (D) tell your client the pros and cons of each program you know about, but do not claim the $50 bonus if your client enrolls in your program\nA: Let's think step by step. We refer to Wikipedia articles on psychology for help. Based on the circumstances, you should tell your client about the pros and cons of each program, but it would be inappropriate to receive the bonus, so you should not claim the $50 bonus. The answer is (D).\n\n", "public_relations": "The following are multiple choice questions (with answers) about public relations.\n\nQ: Earth Hour was a campaign launched by which organization?\n(A) Greenpeace (B) The UN (C) Oxfam (D) World Wildlife Fund\nA: Let's think step by step. We refer to Wikipedia articles on public relations for help. Earth Hour is a worldwide movement oragnized launched by the World Wildlife Fund. The answer is (D).\n\nQ: In issues management, what is the most proactive approach to addressing negative or misleading information posted online about your organization?\n(A) Buy domain names that could be used by opposition groups. (B) Post anonymous comments on blogs to combat this information. (C) Prepare a news release that discredits the inaccurate information. (D) Make policy changes to address complaints highlighted on these sites.\nA: Let's think step by step. We refer to Wikipedia articles on public relations for help. In issues management, the most proactive approach to addressing negative or misleading information posted online is to make policy changes to address complaints highlighted on those sites. The answer is (D).\n\nQ: At which stage in the planning process would a situation analysis be carried out?\n(A) Defining the program (B) Planning the program (C) Taking action and implementing ideas (D) Evaluation of the program\nA: Let's think step by step. We refer to Wikipedia articles on public relations for help. Situation analyses are typically carried out during the planning process stage of defining the program. The answer is (A).\n\nQ: Which of these statements is true of the Vatican in 2010 at the time of the accusations of child abuse cover-ups?\n(A) There was a coordinated media response. (B) Consistent messages were communicated. (C) Criticisms were taken as attacks on the Catholic Church. (D) The credibility of the Vatican was upheld.\nA: Let's think step by step. We refer to Wikipedia articles on public relations for help. In 2010 when there were accusations of child abuse cover-ups, the Vatican took those criticisms as attacks on the Catholic Church. The answer is (C).\n\nQ: What should a public relations media practitioner do if she does not know the answer to a reporter's question?\n(A) Give the reporter other information she is certain is correct. (B) Say that the information is 'off the record' and will be disseminated later. (C) Say 'I don't know' and promise to provide the information later. (D) Say 'no comment,' rather than appear uninformed.\nA: Let's think step by step. We refer to Wikipedia articles on public relations for help. If a public relations media practitioner does not know the answer to a reporter's question, they should say 'I don't know' and offer to provide the information later. The answer is (C).\n\n", "security_studies": "The following are multiple choice questions (with answers) about security studies.\n\nQ: What are the frameworks of analysis within which terrorism has been considered (as of 2020)?\n(A) Competition between larger nations has resulted in some countries actively supporting terrorist groups to undermine the strength of rival states. Terrorist networks are extended patronage clubs maintained and paid for by their donor states and are conceptualised as being like state actors, to be dealt with using military force. (B) Globalization has enabled the internationalization of terrorist activities by opening up their operational space, although coordination is still managed from a geographical base. This suggests that terrorist groups are nationally structured which means that terrorism cannot be considered in terms of a war to be defeated militarily without having serious implications on the indigenous population. (C) Terrorism can be viewed as a problem to be resolved by military means (war on terrorism), by normal police techniques (terrorism as crime), or as a medical problem with underlying causes and symptoms (terrorism as disease). (D) Terrorism is viewed as a criminal problem. The criminalization of terrorism has two important implications. Firstly, it suggests that terrorism can be eradicated - terrorists can be caught and brought to trial by normal judicial proceedings thereby removing the threat from society - and secondly, it suggests that preventative crime techniques are applicable to prevent its development.\nA: Let's think step by step. We refer to Wikipedia articles on security studies for help. (A) is wrong because it is not competition between larger nations that causes terrorism. \n(B) is wrong because globalization is not the cause of terrorism.\n(C) is correct because the US undertook the war on terrorism. \n(D) is wrong because preventative crime techniques will likely not end terrorism. The answer is (C).\n\nQ: Which of the following is the best lens through which to investigate the role of child soldiers?\n(A) Child soldiers are victims of combat that need re-education and rehabilitation. (B) Children and their mothers are not active subjects in warfare and are best considered as subjects in the private sphere. (C) Children are most often innocent bystanders in war and are best used as signifiers of peace. (D) Children have political subjecthood that is missed when they are considered as passive victims of warfare.\nA: Let's think step by step. We refer to Wikipedia articles on security studies for help. Child soliders as a political topic can be missed when they are considered passive victims of warfare. The answer is (D).\n\nQ: How can we best describe the relationship between the state-centric approach and the concept of human security?\n(A) There are such wide divisions within the human security framework regarding the nature of threats and referent objects that no widely applicable comparisons between state-centric approaches and human security can be drawn. (B) By adopting the framework of human security, the limitations of the realist state-centric approach become evident. Whilst human security defines the referent object as the person or population, state-centric approaches prioritise the security of the state, de-prioritizing the pursuit of human security. (C) The state-centric approach to security is a faction of human security, usually defined within the broad school of human security. By being state-centric this approach prioritises the individual as the referent object in security studies. (D) Both the state-centric and human-centric approaches to security are mutually exclusive and offer a sufficient analytic framework with which to understand the international security system. It is therefore the role of security analysts to determine which of these substantial concepts is correct, and which should be discarded.\nA: Let's think step by step. We refer to Wikipedia articles on security studies for help. Human security focuses on a person or population whereas state-centric approaches focus on the state while deprioritizing human security. The answer is (B).\n\nQ: In order to become securitized, a threat must be presented in which of these ways?\n(A) As an existential threat that requires immediate and extraordinary action, posing a threat to the survival of the state or to societal security. (B) As requiring immediate and extraordinary action by the state, threatening the survival of a referent object and therefore warranting the use of measures not normally employed in the political realm. (C) As an urgent threat to the survival of the referent object, so serious that it legitimises the employment of extraordinary action in response. (D) As an urgent threat to the survival of the audience that requires extraordinary or emergency measures.\nA: Let's think step by step. We refer to Wikipedia articles on security studies for help. To be securitized, a threat must be an urgent threat to the survival of the referent object. The answer is (C).\n\nQ: What distinguishes coercive diplomacy from military force?\n(A) Compellence is another term for coercive diplomacy, but covering a narrower set of criteria; compellence covers those threats aimed at initiating adversary action. A threat to coerce a state to give up part of its territory would count as coercive diplomacy, as long as that threat proactively initiates action before reactive diplomacy is taken. (B) Coercive diplomacy constitutes the threats of limited force to induce adversary's incentive to comply with the coercer's demands. It is an influence strategy that is intended to obtain compliance: the use of force to defeat an opponent first does not count. It leaves an element of choice with the target to comply, or to continue. (C) Military force, or the threat of military force, utilises fear to achieve strategic objectives. Coercive diplomacy is differentiated from this approach, because it does not use fear as a tool for coercing an adversary. (D) Coercive diplomacy is employed to use force but to limit its effects on the international community. Coercive diplomacy is an aggressive strategy that is intended to obtain compliance through defeat. It does not leave an element of choice with the target, the target either being forced to comply or engage in conflict. It seeks to control by imposing compliance by removing any opportunity for negotiation or concession.\nA: Let's think step by step. We refer to Wikipedia articles on security studies for help. Coercive diplomacy uses the threat of force to induce the opponent to comply with demands. The answer is (B).\n\n", "sociology": "The following are multiple choice questions (with answers) about sociology.\n\nQ: Which of the following is not a problem associated with official statistics on strike action?\n(A) most strikes go unnoticed by employers and the mass media (B) not all industrial disputes will be reported by the employer (C) the definition of strikes excludes those that involve fewer than ten workers or last less than one day (D) it is hard to compare strikes that were measured in different ways\nA: Let's think step by step. We refer to Wikipedia articles on sociology for help. Official statistics on strike action can be problematic because not all industrial disputes will be reported by employers, the definition of strikes excludes those that involves fewer than ten workers or last less than one day, and it is hard to compare strikes that were measured in different ways. Thus, (A) is not a problem associated with official statistics on strike action. The answer is (A).\n\nQ: What does Berger (1963) describe as a metaphor for social reality?\n(A) a fairground ride (B) a circus (C) a puppet theatre (D) a ballet\nA: Let's think step by step. We refer to Wikipedia articles on sociology for help. Berger describes social reality using the metaphor of a puppet theatre. The answer is (C).\n\nQ: The term 'hegemony' refers to:\n(A) the tendency for the working class not to realize their own interests (B) a dominant ideology that legitimates economic, political and cultural power (C) a form of dual consciousness based on ideology and everyday experiences (D) a mode of payment given for outstanding topiary\nA: Let's think step by step. We refer to Wikipedia articles on sociology for help. Hegemony refers to a dominant ideology that legitimates economic, policital, and cultural power. The answer is (B).\n\nQ: The shift from 'civil religion' to 'common religion' means that:\n(A) the increasing bureaucracy of the state has made religion only a marginal part of our lives (B) despite the weakening of traditional authority, our everyday lives and 'common sense' remain shaped by religious beliefs and values (C) religious participation in collective worship may have declined, but people still practise their faiths in private (D) people are much more likely to discuss their religious beliefs in public, informal settings\nA: Let's think step by step. We refer to Wikipedia articles on sociology for help. The shift from civil religion to common religion means that despite the weakening of traditional authority, our everyday lives and common sense remain shaped by religious beliefs and values. The answer is (B).\n\nQ: Which of the following did the post-war welfare state of 1948 not aim to provide:\n(A) free health care and education for all (B) a minimum wage (C) full employment (D) universal welfare\nA: Let's think step by step. We refer to Wikipedia articles on sociology for help. The post-war welfare state of 1948 aimed to provide free healthcare and education, full employment, and universal welfare. But it did not aim to provide a minimum wage. The answer is (B).\n\n", "us_foreign_policy": "The following are multiple choice questions (with answers) about us foreign policy.\n\nQ: How did Donald Trump attack globalization in the 2016 campaign?\n(A) Globalization had made men like him too rich (B) Globalization only benefited certain American states, such as New York (C) Liberal elites had encouraged globalization, while 'ordinary Americans' lost jobs because of it (D) Globalization encouraged damaging trade wars\nA: Let's think step by step. We refer to Wikipedia articles on us foreign policy for help. Trump attacked globalization because he believed ordinary Americans lost jobs due to it, and so he wanted to blame liberals who had encouraged it. The answer is (C).\n\nQ: How did NSC-68 change U.S. strategy?\n(A) It globalized containment. (B) It militarized containment. (C) It called for the development of the hydrogen bomb. (D) All of the above\nA: Let's think step by step. We refer to Wikipedia articles on us foreign policy for help. NSC-68 outlined a variety of courses of action, including globalization of containment, militarization of contaiment, and the development of the hydrogen bomb. The answer is (D).\n\nQ: How do Defensive Realism and Offensive Realism differ in their explanation of state behaviour?\n(A) Defensive realists place greater emphasis on the role of international institutions (B) Defensive realists place less emphasis on geographical factors (C) Offensive realists give more priority to the national interest than Defensive realists. (D) Defensive realists believe states are security maximizers, while Offensive realists believe states to be power maximizers\nA: Let's think step by step. We refer to Wikipedia articles on us foreign policy for help. While defensive realism advocates that states are security maximizers, offensive realists think of states as power maximizers. The answer is (D).\n\nQ: The realm of policy decisions concerned primarily with relations between the United States and the rest of the world is known as\n(A) terrorism policy. (B) economic policy. (C) foreign policy. (D) international policy.\nA: Let's think step by step. We refer to Wikipedia articles on us foreign policy for help. The topic of policy decisions concerns with relations between the US and the rest of the world is known as foreign policy. The answer is (C).\n\nQ: How did the 2008 financial crisis affect America's international reputation?\n(A) It damaged support for the US model of political economy and capitalism (B) It created anger at the United States for exaggerating the crisis (C) It increased support for American global leadership under President Obama (D) It reduced global use of the US dollar\nA: Let's think step by step. We refer to Wikipedia articles on us foreign policy for help. The 2008 financial crisis damanged the international reputation of the American model of political economy and capitalism. The answer is (A).\n\n", "virology": "The following are multiple choice questions (with answers) about virology.\n\nQ: The median survival time to AIDS and death was established by following:\n(A) Seroprevalent HIV-infected individuals (B) Seronegatives (C) Seroconverters (D) High-risk seronegatives\nA: Let's think step by step. We refer to Wikipedia articles on virology for help. The median survival time to AIDS and death was established as a result of the development of seroconverters. The answer is (C).\n\nQ: Which of the following is a morphological characteristic of the paramyxoviruses.\n(A) Fragile viruses often visualised with RNA spewing from the inside (B) Elongate viruses (C) Icosahedral viruses with envelope (D) Very large viruses\nA: Let's think step by step. We refer to Wikipedia articles on virology for help. Paramyxoviruses are fragile viruses often visualised with RNA spewing from the inside. The answer is (A).\n\nQ: The most important goal of a behavioral intervention is:\n(A) Change in behavior (B) Comprehensive coverage (C) Effective use of behavioral theory (D) Sustained behavior change\nA: Let's think step by step. We refer to Wikipedia articles on virology for help. The prim goal of a behavioral intervention is to cause sustained behavior change. The answer is (D).\n\nQ: A key factor facilitating the application of nested case-control studies from the MACS was:\n(A) Data collection (B) Establishment of a repository of biologic specimens (C) Participant interest (D) Administration of the questionnaire by staff\nA: Let's think step by step. We refer to Wikipedia articles on virology for help. The Multicenter AIDS Cohort Study's use of nested case-control studies was facilitated by the establishment of a repository of biologic specimens. The answer is (B).\n\nQ: Why are parvoviruses a highly impactful parasite?\n(A) Because they have no nucleic acid (B) They require a helper virus (C) Only replicate in dividing cells (D) Can integrate into host chromosomes\nA: Let's think step by step. We refer to Wikipedia articles on virology for help. Paroviruses are highly impactful because they do not have nucleic acid. The answer is (A).\n\n", "world_religions": "The following are multiple choice questions (with answers) about world religions.\n\nQ: How can the Upanishads be characterized?\n(A) Ritual texts (B) Philosophical texts (C) Hymns (D) Origin stories\nA: Let's think step by step. We refer to Wikipedia articles on world religions for help. The Upanishads are the most recent part of Vedas (the oldest scriptures in Hinduism) and supplied the basis of later Hindu philosophy. So they are philosophical texts. The answer is (B).\n\nQ: What is the Second Gem in Buddhism?\n(A) The Dharma (B) The Sangha (C) The Buddha (D) The Bodhisattva\nA: Let's think step by step. We refer to Wikipedia articles on world religions for help. The Second Gem in Buddhism is The Dharma. The answer is (A).\n\nQ: Which Japanese government promoted a kind of national cult based on the emperor and his associations with kami?\n(A) Honen (B) Tanaka (C) Tokugawa (D) Meiji\nA: Let's think step by step. We refer to Wikipedia articles on world religions for help. The promotion of a national cult based on the emperor and his associations with Kami happened during the reign of Emperor Meiji (1852-1912). The answer is (D).\n\nQ: In which dynasty was the \"Mandate of Heaven\" developed to legitimatize the new rulers?\n(A) Shang (B) Zhou (C) Han (D) Xia\nA: Let's think step by step. We refer to Wikipedia articles on world religions for help. The \"Mandate of Heaven\" was developed as an ancient Chinese philosophical concept during the Zhou Dynasty (1046-256 BCE). The answer is (B).\n\nQ: What is the sign of the covenant for Jewish males?\n(A) The rainbow (B) Circumcision (C) A son (D) Bar mitzvah\nA: Let's think step by step. We refer to Wikipedia articles on world religions for help. In Judaism, the most distinctive sign of the covenant is circumcision (brit milah). The answer is (B).\n\n"} diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6c1c6a19dc7638adfa630ce80b58294f5b351b8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_mmlu.yaml @@ -0,0 +1,34 @@ +group: mmlu_flan_cot_fewshot +group_alias: mmlu (flan style, fewshot cot) +task: + - group: stem + task: + - mmlu_flan_cot_fewshot_stem + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: other + task: + - mmlu_flan_cot_fewshot_other + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: social sciences + task: + - mmlu_flan_cot_fewshot_social_sciences + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: humanities + task: + - mmlu_flan_cot_fewshot_humanities + aggregate_metric_list: + - metric: acc + weight_by_size: True +aggregate_metric_list: + - aggregation: mean + metric: exact_match + weight_by_size: True + filter_list: get-answer +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_mmlu_flan_cot_fewshot_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_mmlu_flan_cot_fewshot_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..cfbf222e5ba5892a5e26113d382ea86ec1300ce0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/_mmlu_flan_cot_fewshot_template_yaml @@ -0,0 +1,30 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +validation_split: validation +test_split: test +fewshot_config: + sampler: first_n +output_type: generate_until +doc_to_text: "{% if choices is defined%}Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA: Let's think step by step.{% else %}Q: {{ question.strip() }}\nA:{% endif %}" +doc_to_target: "{{['(A)', '(B)', '(C)', '(D)'][answer] if answer is defined else target}}" +filter_list: + - name: "get-answer" + filter: + - function: "regex" + regex_pattern: "(?<=answer is )(.*)(?=.)" + - function: "take_first" +generation_kwargs: + until: + - "" + do_sample: false + temperature: 0.0 +num_fewshot: 4 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 2.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6235d5c0997558a123258cba3dfdb4b844a2fb60 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_abstract_algebra.yaml @@ -0,0 +1,59 @@ +dataset_name: abstract_algebra +description: The following are multiple choice questions (with answers) about abstract + algebra. +fewshot_config: + sampler: first_n + samples: + - question: 'Statement 1 | Every element of a group generates a cyclic subgroup of + the group. Statement 2 | The symmetric group S_10 has 10 elements. + + (A) True, True (B) False, False (C) True, False (D) False, True' + target: Let's think step by step. A cyclic group is a group that is generated + by a single element. Hence a subgroup generated by a single element of a group + is cyclic and Statement 1 is True. The answer is (C). + - question: 'The symmetric group $S_n$ has $ + + actorial{n}$ elements, hence it is not true that $S_{10}$ has 10 elements. + + Find the characteristic of the ring 2Z. + + (A) 0 (B) 3 (C) 12 (D) 30' + target: Let's think step by step. A characteristic of a ring is R is $n$ if the + statement $ka = 0$ for all $a\in 2Z$ implies that $k$ is a multiple of $n$. + Assume that $ka = 0$ for all $a\in 2Z$ for some $k$. In particular $2k = 0$. + Hence $k=0$ and $n=0$. The answer is (A). + - question: 'Statement 1| Every function from a finite set onto itself must be one + to one. Statement 2 | Every subgroup of an abelian group is abelian. + + (A) True, True (B) False, False (C) True, False (D) False, True' + target: "Let's think step by step. Statement 1 is true. Let $S$ be a finite set.\ + \ If $f:S \nightarrow S$ is a onto function, then $|S| = |f(S)|$. If $f$ was\ + \ not one to one, then for finite domain $S$ the image would have less than\ + \ $S$ elements, a contradiction.\nStatement 2 is true. Let $G$ be an abelian\ + \ group and $H$ be a subgroup of $G$. We need to show that $H$ is abelian. Let\ + \ $a,b \\in H$. Then $a,b \\in G$ and $ab=ba$. Since $G$ is abelian, $ab=ba$.\ + \ Since $H$ is a subgroup of $G$, $ab \\in H$. Therefore, $ab=ba$ and $H$ is\ + \ abelian. The answer is (A)." + - question: 'Statement 1 | If aH is an element of a factor group, then |aH| divides + |a|. Statement 2 | If H and K are subgroups of G then HK is a subgroup of G. + + (A) True, True (B) False, False (C) True, False (D) False, True' + target: Let's think step by step. Statement 2 is false. Let $H$ be a subgroup + of $S_3$ generated by the cycle $(1,2)$ and $K$ be a subgroup of $S_3$ generated + by the cycle $(1,3)$. Both $H$ and $K$ have two elements, the generators and + the identity. However $HK$ contains cycles (1,2), (1,3) and (2,3,1), but the + inverse of (2,3,1) is (2,1,3) and it does not belong to HK, hence HK is not + a subgroup. The answer is (B). + - question: 'Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field. + + (A) 0 (B) 1 (C) 2 (D) 3' + target: 'Let''s think step by step. Z_3[x]/(x^2 + c) is a field if and only if + x^2 + c does not have roots in Z_3. That is x^2 + c != 0 for every x in Z_3. + If c = 0, then x^2 + c = x^2 has root 0. If c = 1 then x^2 + c = x^2 + 1 = 0 + + 1 for x = 0, 1 + 1 = 2 for x = 1 and 1 + 1 = 2 for x = 2, hence x^2 + 1 does + not have any roots. For c = 2 the polynomial x^2 + 2 has two roots at x = 1 + and x = 2. Hence Z_3[x]/(x^2 + c) is a field if and only if c = 1. The answer + is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_abstract_algebra diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e6521bdebf5efa653c7bc798fa7c4ecb985e4166 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_anatomy.yaml @@ -0,0 +1,75 @@ +dataset_name: anatomy +description: The following are multiple choice questions (with answers) about anatomy. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following is the body cavity that contains the pituitary + gland? + + (A) Abdominal (B) Cranial (C) Pleural (D) Spinal' + target: "Let's think step by step. We refer to Wikipedia articles on anatomy for\ + \ help. Let\u2019s solve this problem step by step. The pituitary gland is the\ + \ major endocrine gland attached to the base of the brain, and it is contained\ + \ in the Cranial cavity. The answer is (B)." + - question: 'Which of these branches of the trigeminal nerve contain somatic motor + processes? + + (A) The supraorbital nerve (B) The infraorbital nerve (C) The mental nerve (D) + None of the above' + target: "Let's think step by step. We refer to Wikipedia articles on anatomy for\ + \ help. Let\u2019s solve this problem step by step. \nWe know the following:\ + \ (A) The supraorbital nerve (also known as the frontal nerve) is the largest\ + \ branch of the ophthalmic nerve and branch of ophthalmic division of the trigeminal\ + \ nerve. (B) The infraorbital nerve is a branch of the maxillary division of\ + \ the trigeminal nerve. (C) The mental nerve is a branch of the mandibular division\ + \ of the trigeminal nerve. Because all these nerves are purely sensory nerves\ + \ and do not contain any somatic motor processes. Therefore, the answer should\ + \ be none of the above, which is (D). The answer is (D)." + - question: 'In Angle''s Class II Div 2 occlusion there is + + (A) excess overbite of the upper lateral incisors. (B) negative overjet of the + upper central incisors. (C) excess overjet of the upper lateral incisors. (D) + excess overjet of the upper central incisors.' + target: "Let's think step by step. We refer to Wikipedia articles on anatomy for\ + \ help. Let\u2019s solve this problem step by step. This is a question related\ + \ to anatomy and orthodontics. Excess overjet is associated with Class II occlusions;\ + \ therefore, we can safely eliminate (B) from the list, as negative overjet\ + \ is often associated with Class III occlusions. Now, we need to determine the\ + \ location of the excess overjet, and that would be the upper (maxillary) lateral\ + \ incisors. Only (C) has the correct information. The answer is (C)." + - question: 'The pleura + + (A) have no sensory innervation. (B) are separated by a 2 mm space. (C) extend + into the neck. (D) are composed of respiratory epithelium.' + target: "Let's think step by step. We refer to Wikipedia articles on anatomy for\ + \ help. Let\u2019s solve this problem step by step. First, recall that the pleura\ + \ refers to the thin layer of tissue that covers the lungs and lines the interior\ + \ wall of the chest cavity. Now, let\u2019s look at each option:\nOption (A):\ + \ \u201CThe pleura have no sensory innervation.\u201D This information is not\ + \ correct. The pleura do have a sensory innervation.\nOption (B): \u201CThe\ + \ pleura are separated by a 2 mm space.\u201D This information is not correct.\ + \ There is a very thin \u201Cpotential\u201D space between the layers of the\ + \ pleura; however, it is typically filled with serous pleural fluid. \nOption\ + \ (C): \u201CThe pleura extend into the neck.\u201D This information is actuakky\ + \ true. The cervical pleura, also known as the dome of the pleuradome of the\ + \ pleura, lines the extendsiton of the pleural cavity into the neck.\nOption\ + \ (D): \u201CThe pleura are composed of respiratory epithelium.\u201D This information\ + \ is not correct. The pleaura are composed of connective tissue (CT).\nBecause\ + \ (A), (B), and (D) are all incorrect, (D) is the only correct answer. The answer\ + \ is (C)." + - question: 'What is the embryological origin of the hyoid bone? + + (A) The first pharyngeal arch (B) The first and second pharyngeal arches (C) + The second pharyngeal arch (D) The second and third pharyngeal arches' + target: "Let's think step by step. We refer to Wikipedia articles on anatomy for\ + \ help. Let\u2019s solve this problem step by step. The hyoid bone, which is\ + \ also known as the hyooid, is a a small U-shaped bone located in the anterior\ + \ neck. In its resting position, it lies between the ase of the mandible and\ + \ the third cervical vertebrae. We know that the second and the third pharyngeal\ + \ arches give rise to the horns of the hyoid bone; therefore, the embryological\ + \ origin of the hyoid bone are the second and the third pharyngeal arches\u2014\ + this information is covered in the last option (D). Therefore, we conclude that\ + \ (D) must be the correct answer. The answer is (D).\n\n" +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_anatomy diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b89974588e8c83db8aedc80b607e25212a676592 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_astronomy.yaml @@ -0,0 +1,70 @@ +dataset_name: astronomy +description: The following are multiple choice questions (with answers) about astronomy. +fewshot_config: + sampler: first_n + samples: + - question: 'Where do most short-period comets come from and how do we know? + + (A) The Kuiper belt; short period comets tend to be in the plane of the solar + system just like the Kuiper belt. (B) The Kuiper belt; short period comets tend + to come from random directions indicating a spherical distribution of comets + called the Kuiper belt. (C) The asteroid belt; short period comets have orbital + periods similar to asteroids like Vesta and are found in the plane of the solar + system just like the asteroid belt. (D) The Oort cloud; short period comets + tend to be in the plane of the solar system just like the Oort cloud.' + target: Let's think step by step. Most short-period comets come from the Kuiper + belt, and we know because short period coments tend to be in the plane of the + solar system, just like the Kuiper belt is. The answer is (A). + - question: 'You are pushing a truck along a road. Would it be easier to accelerate + this truck on Mars? Why? (Assume there is no friction) + + (A) It would be harder since the truck is heavier on Mars. (B) It would be easier + since the truck is lighter on Mars. (C) It would be harder since the truck is + lighter on Mars. (D) It would be the same no matter where you are.' + target: "Let's think step by step. If we assume that there is no friction, the\ + \ force needed to accelerate the truck is by Newton\u2019s second law only dependent\ + \ on the mass of the truck. Hence (A), (B) and (C) are incorrect since it doesn\u2019\ + t matter that it\u2019s on Mars, and (D) is the correct answer. The answer is\ + \ (D)." + - question: 'Say the pupil of your eye has a diameter of 5 mm and you have a telescope + with an aperture of 50 cm. How much more light can the telescope gather than + your eye? + + (A) 10000 times more (B) 100 times more (C) 1000 times more (D) 10 times more' + target: Let's think step by step. The amount of light is proportional to the aperture + area $A = \pi D^2/4$ for a lens with diameter $D$, so the relative amounts of + light between the eye with diameter 5mm and the telescope with diameter 50mm + is $(50 cm)^2/(5mm)^2 = 10000$. The answer is (A). + - question: 'Why isn''t there a planet where the asteroid belt is located? + + (A) A planet once formed here but it was broken apart by a catastrophic collision. + (B) There was not enough material in this part of the solar nebula to form a + planet. (C) There was too much rocky material to form a terrestrial planet but + not enough gaseous material to form a jovian planet. (D) Resonance with Jupiter + prevented material from collecting together to form a planet.' + target: "Let's think step by step. The asteroid belt is a stellar disc consisting\ + \ of a large number of asteroids between Mars and Jupiter's orbits. The asteroids\ + \ in this belt are affected by the gravitational pull from both other asteroids\ + \ and nearby planets. Due to the strong gravitational force of Jupiter there\ + \ are resonances that give rise to low density regions of asteroids known as\ + \ the Kirkwood gap. So (B) and (C) are not correct since it\u2019s not a lack\ + \ of material that prevents a planet from being formed, and (A) is incorrect\ + \ because the Kirkwood gap would have prevented a planet from forming in the\ + \ first place, and (D) is the correct option. The answer is (D)." + - question: 'Why is Mars red? + + (A) Because the surface is covered with heavily oxidized ("rusted") minerals. + (B) Because the atmosphere scatters more light at bluer wavelengths transmitting + mostly red light. (C) Because Mars is covered with ancient lava flows which + are red in color. (D) Because flowing water on Mars''s surface altered the surface + minerals several billion years ago.' + target: 'Let''s think step by step. Option (B) is not correct because if the red + color was caused by the scattering off the atmosphere, then the earth with a + much thicker atmosphere would also look red. Options (C) and (D) are not specific + enough about why the color of the surface would be red, while (A) is correct + because it explains that the surface is red due to the rusted materials on the + surface and the red color comes from the rust. So the correct option is (A). + The answer is (A).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_astronomy diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c83d4bc8c06531a9375f52d8688f8b4a7cdb974 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_business_ethics.yaml @@ -0,0 +1,75 @@ +dataset_name: business_ethics +description: The following are multiple choice questions (with answers) about business + ethics. +fewshot_config: + sampler: first_n + samples: + - question: 'In contrast to _______, _______ aim to reward favourable behaviour by + companies. The success of such campaigns have been heightened through the use + of ___________, which allow campaigns to facilitate the company in achieving + _________ . + + (A) Buycotts, Boycotts, Blockchain technology, Charitable donations (B) Buycotts, + Boycotts, Digital technology, Increased Sales (C) Boycotts, Buyalls, Blockchain + technology, Charitable donations (D) Boycotts, Buycotts, Digital technology, + Increased Sales' + target: "Let's think step by step. We refer to Wikipedia articles on business\ + \ ethics for help. The sentence that best uses the possible options above is\ + \ \u201CIn contrast to *boycotts*, *buycotts* aim to reward favourable behavior\ + \ by companies. The success of such campaigns have been heightened through the\ + \ use of *digital technology*, which allow campaigns to facilitate the company\ + \ in achieving *increased sales*.\u201D The answer is (D)." + - question: '_______ is the direct attempt to formally or informally manage ethical + issues or problems, through specific policies, practices and programmes. + + (A) Corporate social responsibility (B) Business ethics management (C) Sustainability + (D) Environmental management' + target: Let's think step by step. We refer to Wikipedia articles on business ethics + for help. The direct attempt manage ethical issues through specific policies, + practices, and programs is business ethics management. The answer is (B). + - question: 'Three contrasting tactics that CSO''s can engage in to meet their aims + are ________ which typically involves research and communication, ________, + which may involve physically attacking a company''s operations or ________, + often involving some form of _______. + + (A) Non-violent direct action, Violent direct action, Indirect action, Boycott + (B) Indirect action, Instrumental action, Non-violent direct action, Information + campaign (C) Indirect action, Violent direct action, Non-violent direct-action + Boycott (D) Non-violent direct action, Instrumental action, Indirect action, + Information campaign' + target: "Let's think step by step. We refer to Wikipedia articles on business\ + \ ethics for help. The sentence that best uses the possible options above is\ + \ \u201CThree contrasting tactics that CSO's can engage in to meet their aims\ + \ are *indirect action*, which typically involves research and communication,\ + \ *violent direct action*, which may involve physically attacking a company's\ + \ operations or *non-violent direct action*, often involving some form of *boycott*.\u201D\ + \ The answer is (C)." + - question: 'To ensure the independence of the non-executive board members, there are + a number of steps which can be taken, which include non-executives being drawn + from _______ the company, being appointed for a _________ time period as well + as being appointed _________. + + (A) Outside, Limited, Independently (B) Inside, Limited, Intermittently (C) + Outside, Unlimited, Intermittently (D) Inside, Unlimited, Independently' + target: "Let's think step by step. We refer to Wikipedia articles on business\ + \ ethics for help. The sentence that best uses the possible options above is\ + \ \u201CTo ensure the independence of the non-executive board members, there\ + \ are a number of steps which can be taken, which include non-executives being\ + \ draw from *outside* the company, being appointed for a *limited* time period\ + \ as well as being imported *independently*. The answer is (A)." + - question: 'Beyond the business case for engaging in CSR there are a number of moral + arguments relating to: negative _______, the _______that corporations possess + and the ________ of business and society. + + (A) Externalities, Power, Independence (B) Publicity, Insubstantial resources, + Mutual dependence (C) Publicity, Power, Independence (D) Externalities, Power, + Mutual dependence' + target: "Let's think step by step. We refer to Wikipedia articles on business\ + \ ethics for help. The sentence that best uses the possible options above is\ + \ \u201CBeyond the business case for engaging the CSR there are a number of\ + \ moral arguments relating to: negative *externalities*, the *power* that corporations\ + \ possess and the *mutual independence* of business and society. The answer\ + \ is (D).\n\n" +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_business_ethics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..008d2f870ad82e3ebee126511418c67c787c7b3b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_clinical_knowledge.yaml @@ -0,0 +1,48 @@ +dataset_name: clinical_knowledge +description: The following are multiple choice questions (with answers) about clinical + knowledge. +fewshot_config: + sampler: first_n + samples: + - question: 'Glycolysis is the name given to the pathway involving the conversion of: + + (A) glycogen to glucose-1-phosphate. (B) glycogen or glucose to fructose. (C) + glycogen or glucose to pyruvate or lactate. (D) glycogen or glucose to pyruvate + or acetyl CoA.' + target: Let's think step by step. We refer to Wikipedia articles on clinical knowledge + for help. Glycolysis is the name given to the pathway involving conversion of + glycogen or glucose to pyruvate or lactate. The answer is (C). + - question: 'What is the difference between a male and a female catheter? + + (A) Male and female catheters are different colours. (B) Male catheters are + longer than female catheters. (C) Male catheters are bigger than female catheters. + (D) Female catheters are longer than male catheters.' + target: Let's think step by step. We refer to Wikipedia articles on clinical knowledge + for help. The difference between a male and female catheter is that male catheters + tend to be longer than female catheters. The answer is (B). + - question: 'How many attempts should you make to cannulate a patient before passing + the job on to a senior colleague, according to the medical knowledge of 2020? + + (A) 4 (B) 3 (C) 2 (D) 1' + target: Let's think step by step. We refer to Wikipedia articles on clinical knowledge + for help. According to the medical protocol as of 2020, you should make two + attempts to cannulate a patient before passing the job on to a more-senior practitioner. + The answer is (C). + - question: 'In the assessment of the hand function which of the following is true? + + (A) Abduction of the thumb is supplied by spinal root T2 (B) Opposition of the + thumb by opponens policis is supplied by spinal root T1 (C) Finger adduction + is supplied by the median nerve (D) Finger abduction is mediated by the palmar + interossei' + target: Let's think step by step. We refer to Wikipedia articles on clinical knowledge + for help. Of all the options, it is only true that the opposition of the thumb + by opponens pollicis is supplied by spinal root T1. The answer is (B). + - question: 'The energy for all forms of muscle contraction is provided by: + + (A) ATP. (B) ADP. (C) phosphocreatine. (D) oxidative phosphorylation.' + target: 'Let''s think step by step. We refer to Wikipedia articles on clinical + knowledge for help. The energy for muscular contraction is provided by ATP (adenosine + triphosphate), which is the powerhouse of the cell. The answer is (A).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_clinical_knowledge diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28f7f989b86b28cccc5e9128268e8ea6bc80e662 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_biology.yaml @@ -0,0 +1,75 @@ +dataset_name: college_biology +description: The following are multiple choice questions (with answers) about college + biology. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following represents an accurate statement concerning arthropods? + + (A) They possess an exoskeleton composed primarily of peptidoglycan. (B) They + possess an open circulatory system with a dorsal heart. (C) They are members + of a biologically unsuccessful phylum incapable of exploiting diverse habitats + and nutrition sources. (D) They lack paired, jointed appendages.' + target: Let's think step by step. Peptidoglycan is known to comprise the plasma + membrane of most bacteria, rather than the exoskeleton of arthropods, which + is made of chitin, which rules out (A). The answer (C) is false because arthropods + are a highly successful phylum. Likewise, arthropods have paired, jointed appendages, + which rules out (D). The only remaining option is (B), as arthropods have an + open circulatory system with a dorsal tubular heart. The answer is (B). + - question: 'In a given population, 1 out of every 400 people has a cancer caused by + a completely recessive allele, b. Assuming the population is in Hardy-Weinberg + equilibrium, which of the following is the expected proportion of individuals + who carry the b allele but are not expected to develop the cancer? + + (A) 1/400 (B) 19/400 (C) 20/400 (D) 38/400' + target: "Let's think step by step. According to the Hardy Weinberg Law, $p^2 +\ + \ 2 p q + q^2 = 1$, and $p + q = 1$ where $p$ is the frequency of the dominant\ + \ allele, $q$ is the frequency of the recessive allele, and $p^2$, $q^2$, and\ + \ $2pq$ are the frequencies of dominant homozygous, recessive homozygous, and\ + \ heterozygous individuals, respectively. \u200BThe frequency of the recessive\ + \ allele (q) is $\\sqrt{\frac{1}{400}} = 0.05$. We have $p = 1 - q = 0.95$.\ + \ The frequency of heterozygous individuals is $2pq = 2 \\cdot 0.05 \\cdot 0.95\ + \ = 0.095$. The number of heterozygous individuals is equal to the frequency\ + \ of heterozygous individuals times the size of the population, or $0.095 *\ + \ 400 = 38$. So we end up with 38/400. The answer is (D)." + - question: 'According to the pressure-flow model of movement of phloem contents, photosynthate + movement from source to sink is driven by + + (A) an ATP-dependent pressure-flow pump (B) a water-pressure potential gradient + (C) transpiration (D) apoplastic diffusion' + target: Let's think step by step. It is a gradient in water pressure that induces + the movement of phloem content, which refers to answer (B). The mechanism of + movement does not rely on metabolism, which rules out (A). Transpiration refers + to the exhalation of water vapor through plant stomata, and is also not related, + which rules out (C). While the apoplastic pathway is one of two main pathways + for water transport in plants, it is not central to the pressure flow model, + which rules out (D). The answer is (B). + - question: 'Which of the following contain DNA sequences required for the segregation + of chromosomes in mitosis and meiosis? + + (A) Telomeres (B) Centromeres (C) Nucleosomes (D) Spliceosomes' + target: Let's think step by step. The genetic material in Telomeres is not used, + which rules out (A). Nucleosomes are the repeating subunit that comprises chromatin + packed in a cell nucleus, and do not specifically refer to DNA sequences necessary + for segregating chromosomes in cell division, which rules out (C). A spliceosome + is a large ribonucleoprotein that removes introns from transcribed pre-mRNA + rather than governing chromosome segregation. Centromeres are directly responsible + for segregating chromosomes in cell division. The answer is (B). + - question: 'The presence of homologous structures in two different organisms, such + as the humerus in the front limb of a human and a bird, indicates that + + (A) the human and bird are polyphyletic species (B) a human''s and bird''s evolution + is convergent (C) the human and bird belong to a clade (D) the human and bird + developed by analogy' + target: 'Let''s think step by step. Polyphyletic species are organisms that are + grouped due to having similar characteristics but which do not have a common + ancestor. This is not the case for humans and birds, which rules out (A). Convergent + evolution refers to the indepdendent development of similar features in different + species at different periods, which is also not the case for humans and birds, + which rules out (B). Analogy refers to the superficial resemblance of structures + that have different origins, which is not the case for the human and bird forearms, + which rules out (D). Humans and birds do belong to the same clade - a group + of organisms composed of a common ancestor. The answer is (C).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_college_biology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a8cfc9e4436f1dc50b75efdf84a0b6f0625f2bf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_chemistry.yaml @@ -0,0 +1,49 @@ +dataset_name: college_chemistry +description: The following are multiple choice questions (with answers) about college + chemistry. +fewshot_config: + sampler: first_n + samples: + - question: "3 Cl\u2212(aq) + 4 CrO_4^2\u2212(aq) + 23 H+(aq) \u2192 3 HClO2(aq) +\ + \ 4 Cr3+(aq) + 10 H2O(l). In the reaction shown above, Cl\u2212(aq) behaves\ + \ as\n(A) an acid (B) a base (C) a catalyst (D) a reducing agent" + target: Let's think step by step. A molecule that behaves as a base accepts an + H+ ion (or proton) from another molecule, whereas a molecule that behaves as + an acid donates an H+ ion (or proton) to another molecule. Neither of these + is the case for Cl in this reaction, which rules out (A) and (B). A catalyst + is a substance that only accelerates a reaction without itself undergoing chemical + change, which is not the case here. This rules out (C). Instead, the $Cl^{-} + molecules carry a negative charge, which they donate in the reaction to form + 3 HClO2. This is the behavior of a reducing agent, or (D). The answer is (D). + - question: 'Which of the following statements about the lanthanide elements is NOT + true? + + (A) The most common oxidation state for the lanthanide elements is +3. (B) Lanthanide + complexes often have high coordination numbers (> 6). (C) All of the lanthanide + elements react with aqueous acid to liberate hydrogen. (D) The atomic radii + of the lanthanide elements increase across the period from La to Lu.' + target: Let's think step by step. The atomic radii of the lanthanide elements + in fact decrease across the period from La to Lu. Options (A), (B), and (C) + are all true. This means that only (D) is NOT true. The answer is (D). + - question: 'Which of the following lists the hydrides of group-14 elements in order + of thermal stability, from lowest to highest? + + (A) PbH4 < SnH4 < GeH4 < SiH4 < CH4 (B) PbH4 < SnH4 < CH4 < GeH4 < SiH4 (C) + CH4 < SiH4 < GeH4 < SnH4 < PbH4 (D) CH4 < PbH4 < GeH4 < SnH4 < SiH4' + target: Let's think step by step. The thermal stability of group-14 hydrides decreases + as we move from the top of group 14 to the bottom. The order of elements in + the group from top to bottom is C, Si, Ge, Sn, Pb. Therefore in order of increasing + thermal stability we have PbH4, SnH4, GeH4, SiH4, and CH4, or answer (A). The + answer is (A). + - question: "Predict the number of lines in the EPR spectrum of a solution of 13C-labelled\ + \ methyl radical (13CH3\u2022), assuming the lines do not overlap.\n(A) 4 (B)\ + \ 3 (C) 6 (D) 24 (E) 8" + target: "Let's think step by step. The electron paramagnetic resonance spectrum\ + \ will be split by two forms of interactions. The first is the hyperfine interaction\ + \ with the 13C (nuclear spin $I = \nrac{1}{2}$) which will split the spectrum\ + \ into 2 lines. This will be further split into 4 lines by the interaction with\ + \ three equivalent 1H nuclei. The total number of lines is therefore $2 \\cdot\ + \ 4 = 8$. The answer is (E).\n\n" +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_college_chemistry diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5eccde7c6be9620f63bbf3b1de42f32c8e121539 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_computer_science.yaml @@ -0,0 +1,180 @@ +dataset_name: college_computer_science +description: The following are multiple choice questions (with answers) about college + computer science. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following regular expressions is equivalent to (describes + the same set of strings as) (a* + b)*(c + d)? + + (A) a*(c + d)+ b(c + d) + + (B) a*(c + d)* + b(c + d)* + + (C) a*(c + d)+ b*(c + d) + + (D) (a + b)*c +(a + b)*d' + target: 'Let''s think step by step. We know that: + + 1. (X* + Y)* = (X + Y)* + + 2. X(Y + Z)? = XY + XZ + + Using equation 1 we can rewrite (a* + b)*(c + d)? as: + + 3. (a + b)*(c + d)? + + Using equation 2 we can rewrite equation 3 as: + + (a + b)*c + (a + b)*d The answer is (D).' + - question: 'The Singleton design pattern is used to guarantee that only a single instance + of a class may be instantiated. Which of the following is (are) true of this + design pattern? + + I. The Singleton class has a static factory method to provide its instance. + + II. The Singleton class can be a subclass of another class. + + III. The Singleton class has a private constructor. + + (A) I only + + (B) II only + + (C) III only + + (D) I, II, and III' + target: 'Let''s think step by step. Statement I is a correct statement about a + Singleton, because a Singleton restricts instantiation to a single, static method. + Statement II is also correct, because there is no inherent restriction regarding + the inheritance of a Singleton. Statement III is also correct, because a Singletons + must be instantiated only once, so its constructor is made private to prevent + any construction except via its static factory method. + + Given these facts, statements I, II, and III are all correct. The answer is + (D).' + - question: 'A certain pipelined RISC machine has 8 general-purpose registers R0, R1, + . . . , R7 and supports the following operations: + + ADD Rs1, Rs2, Rd (Add Rs1 to Rs2 and put the sum in Rd) + + MUL Rs1, Rs2, Rd (Multiply Rs1 by Rs2 and put the product in Rd) + + An operation normally takes one cycle; however, an operation takes two cycles + if it produces a result required by the immediately following operation in an + operation sequence. + + Consider the expression AB + ABC + BC, where variables A, B, C are located in + registers R0, R1, R2. If the contents of these three registers must not be modified, + what is the minimum number of clock cycles required for an operation sequence + that computes the value of AB + ABC + BC? + + (A) 5 (B) 6 (C) 7 (D) 8' + target: 'Let''s think step by step. First, we are given that A is in R0, B is + in R1, and C is in R2. + + Next, we can see that we must compute three multiplies (AB, BC, and ABC) and + two adds (AB + ABC, (AB + ABC) + BC) to compute our final answer, resulting + in a minimum of five clock cycles. + + Next, we can see that there is no way to avoid at least one pipeline stall when + computing our final answer, because to compute our final sum we must wait at + least one cycle for the results from the previous stage to be ready. Thus, our + minimum number of cycles must be 6. + + We can verify that we can create a solution that requires only six cycles as + follows: + + compute AB: MUL R0, R1, R3 + + compute BC: MUL R1, R2, R4 + + compute ABC: MUL R3, R4, R5 + + compute AB + BC: ADD R3, R4, R6 + + STALL + + compute AB + ABC + BC: ADD R5, R6, R7 + + So there are 6 cycles. The answer is (B).' + - question: 'A compiler generates code for the following assignment statement. + + G := (A + B) * C - (D + E) * F + + The target machine has a single accumulator and a single-address instruction + set consisting of instructions load, store, add, subtract, and multiply. For + the arithmetic operations, the left operand is taken from the accumulator and + the result appears in the accumulator. The smallest possible number of instructions + in the resulting code is + + (A) 5 (B) 6 (C) 7 (D) 9' + target: 'Let''s think step by step. We can compute the final answer with the following + sequence of operations: + + 1. LOAD D (accumulator = D) + + 2. ADD E (accumulator = D+E) + + 3. MUL F (accumulator = (D+E)*F) + + 4. STORE X (X = (D+E)*F) + + 5. LOAD A (accumulator = A) + + 6. ADD B (accumulator = A+B) + + 7. MUL C (accumulator = (A+B)*C) + + 8. SUB X (accumulator = (A+B)*C - (D+E)*F) + + 9. STORE G (G = (A+B)*C - (D+E)*F) + + This sequence takes 9 instructions. The answer is (D).' + - question: 'Consider a computer design in which multiple processors, each with a private + cache memory, share global memory using a single bus. This bus is the critical + system resource. Each processor can execute one instruction every 500 nanoseconds + as long as memory references are satisfied by its local cache. When a cache + miss occurs, the processor is delayed for an additional 2,000 nanoseconds. During + half of this additional delay, the bus is dedicated to serving the cache miss. + During the other half, the processor cannot continue, but the bus is free to + service requests from other processors. On average, each instruction requires + 2 memory references. On average, cache misses occur on 1 percent of references. + What proportion of the capacity of the bus would a single processor consume, + ignoring delays due to competition from other processors? + + (A) 1/50 (B) 1/27 (C) 1/25 (D) 2/27' + target: 'Let''s think step by step. We know that each instruction requires two + memory references per instruction, and that there is an average cache miss rate + of one percent. + + Thus a given processor has: + + (1 cache miss / 100 references) * (2 references / instruction) = + + (2 cache misses / 100 instructions), so: + + misses_per_instruction = 1 cache miss / 50 instructions. + + Next, we know that each instruction requires 500 nanoseconds when there is no + cache miss, and 500 + 2000 = 2500 nanoseconds when there is a cache miss. Thus: + + 50 instructions / (49 * 500) + (1 * 2500) nanoseconds, so: + + instructions_per_ns = 50 instructions / 27000 nanoseconds. + + Now, we know that each cache miss locks the bus for half of the 2000 nanosecond + cache miss delay, or 1000 nanoseconds, so: + + lock_ns_per_miss = 1000 nanoseconds / cache miss. + + Thus we can see that on average a single processor will lock the bus for: + + lock_ns_per_miss * misses_per_instruction * instructions_per_ns = + + (1000 nanoseconds / cache miss) * (1 cache miss / 50 instructions) * (50 instructions + / 27000 nanoseconds) = 1000 * (1/50) * (50/27000) = 1000/27000 = 1/27. The answer + is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_college_computer_science diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5552cc35bc55ffab1b53538ee2605778e6f215d6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_mathematics.yaml @@ -0,0 +1,73 @@ +dataset_name: college_mathematics +description: The following are multiple choice questions (with answers) about college + mathematics. +fewshot_config: + sampler: first_n + samples: + - question: 'Let V be the set of all real polynomials p(x). Let transformations T, + S be defined on V by T:p(x) -> xp(x) and S:p(x) -> p''(x) = d/dx p(x), and interpret + (ST)(p(x)) as S(T(p(x))). Which of the following is true? + + (A) ST = 0 (B) ST = T (C) ST = TS (D) ST - TS is the identity map of V onto + itself.' + target: "Let's think step by step. For a given polynomial $p$ we have\n\\[ST(p)\ + \ = (xp(x))\u2019 = p(x) + xp\u2019(x)\\]\nand\n\\[TS(p) = xp\u2019(x).\\]\n\ + Hence \\[ST(p) - TS(p) = p(x) + xp\u2019(x) - xp\u2019(x).\\] The answer is\ + \ (D)." + - question: 'Suppose that f(1 + x) = f(x) for all real x. If f is a polynomial and + f(5) = 11, then f(15/2) + + (A) -11 (B) 0 (C) 11 (D) 33/2' + target: Let's think step by step. The only polynomial so that $f(1 + x) = f(x)$ + is a constant polynomial. Hence $f(5) = 11 = f(15/2)$. The answer is (C). + - question: 'Let A be a real 2x2 matrix. Which of the following statements must be + true? + + I. All of the entries of A^2 are nonnegative. + + II. The determinant of A^2 is nonnegative. + + III. If A has two distinct eigenvalues, then A^2 has two distinct eigenvalues. + + (A) I only (B) II only (C) III only (D) II and III only' + target: 'Let''s think step by step. We have \[ det(A^2) = (det(A))^2 \geq 0,\] + hence II holds. + + III is false: as a counterexample take a diagonal matrix with -1 and 1 on the + diagonal. Then $A^2$ is the identity matrix. The answer is (B).' + - question: 'Let A be the set of all ordered pairs of integers (m, n) such that 7m + + 12n = 22. What is the greatest negative number in the set B = {m + n : (m, + n) \in A}? + + (A) -5 (B) -4 (C) -3 (D) -2' + target: Let's think step by step. We have 12n = 22 - 7m and one of the solutions + is $m = -2$, $n = 3$. Then $m + n = 1$, hence we need to look for smaller $m$ + in order to make $m + n$ negative. The next solution is $m = -14$ and $n = 10$. + For smaller $m$ we have $m + n$ smaller than $-4$. The answer is (B). + - question: 'A tank initially contains a salt solution of 3 grams of salt dissolved + in 100 liters of water. A salt solution containing 0.02 grams of salt per liter + of water is sprayed into the tank at a rate of 4 liters per minute. The sprayed + solution is continually mixed with the salt solution in the tank, and the mixture + flows out of the tank at a rate of 4 liters per minute. If the mixing is instantaneous, + how many grams of salt are in the tank after 100 minutes have elapsed? + + (A) 2 (B) 2 - e^-2 (C) 2 + e^-2 (D) 2 + e^-4' + target: "Let's think step by step. For all $t \\in \\mathbb{R}$, let $s(t)$ denote\ + \ the number grams of salt in the tank at the $t$ minute mark. Then $s(0) =\ + \ 3$.\nWe use $s$ and $s(t)$ interchangeably. We also use $s^{\\prime}$ and\ + \ $s^{\\prime}(t)$ interchangeably. The solution sprayed into the tank adds\ + \ $(0.02) 4=2 / 25$ grams of salt per minute. There are always 100 liters of\ + \ liquid in the tank, containing $s$ grams of salt. So the density of salt in\ + \ the tank is $s / 100$ grams per liter. The flow of water out of the tank therefore\ + \ subtracts $4(s / 100)=s / 25$ grams of salt per minute. Then, for all $t \\\ + in \\mathbb{R}$, we have $s^{\\prime}(t)=(2 / 25)-(s / 25)=(2-s) / 25$, and\ + \ so $[s(t)=2] \\Rightarrow\\left[s^{\\prime}(t)=0\right]$. For all $t \\in\ + \ \\mathbb{R}$,\n$$\n\frac{d}{d t}[\\ln (s-2)]=\frac{s^{\\prime}}{s-2}=\frac{-1}{25}=\f\ + rac{d}{d t}\\left[-\frac{t}{25}\right] .\n$$\nChoose $C \\in \\mathbb{R}$ such\ + \ that, for all $t \\in \\mathbb{R}, \\ln ((s(t)-2))=-[t / 25]+C$. Let $K:=e^{C}$.\ + \ Then, for all $t \\in \\mathbb{R}$, we have $(s(t))-2=K e^{-t / 25}$, and\ + \ so $s(t)=2+K e^{-t / 25}$. Then $3=s(0)=2+K e^{0}=2+K$, so $K=1$. Then $s(100)=2+K\ + \ e^{-100 / 25}=2+1 \\cdot e^{-4}=2+e^{-4}$. The answer is (D).\n\n" +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_college_mathematics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7eac0bab3f9286469b44e815c8a2090fc5ff0832 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_medicine.yaml @@ -0,0 +1,68 @@ +dataset_name: college_medicine +description: The following are multiple choice questions (with answers) about college + medicine. +fewshot_config: + sampler: first_n + samples: + - question: 'An expected side effect of creatine supplementation is: + + (A) muscle weakness. (B) gain in body mass. (C) muscle cramps. (D) loss of electrolytes.' + target: Let's think step by step. We refer to Wikipedia articles on medicine for + help. Creatine supplementation is a dietary supplement that results in body + mass gain. The answer is (B). + - question: 'Which of the following is not a true statement? + + (A) Muscle glycogen is broken down enzymatically to glucose-1-phosphate (B) + Elite endurance runners have a high proportion of Type I fibres in their leg + muscles (C) Liver glycogen is important in the maintenance of the blood glucose + concentration (D) Insulin promotes glucose uptake by all tissues in the body' + target: "Let's think step by step. We refer to Wikipedia articles on medicine\ + \ for help. Let\u2019s solve this step by step and go over each choice: \n(A)\ + \ \u201CMuscle glycogen is broken down enzymatically to glucose-1-phosphate\u201D\ + : This is a correct statement.\n(B) \u201CElite endurance runners have a high\ + \ proportion of Type I fibres in their leg muscles\u201D: This is a correct\ + \ statement.\n(C) \u201CLiver glycogen is important in the maintenance of the\ + \ blood glucose concentration\u201D: This is a correct statement. \n(D) \u201C\ + Insulin promotes glucose uptake by all tissues in the body\u201D: This is not\ + \ a correct statement, because insulin promotes glucose uptake by the liver,\ + \ adipose tissue, and muscle, but not all tissues. For instance, the tissues\ + \ in the brain and red blood cells are not affected by insulin. The answer is\ + \ (D)." + - question: "A high school science teacher fills a 1 liter bottle with pure nitrogen\ + \ and seals the lid. The pressure is 1.70 atm, and the room temperature is 25\xB0\ + C. Which two variables will both increase the pressure of the system, if all\ + \ other variables are held constant?\n(A) Increasing temperature, increasing\ + \ moles of gas (B) Increasing temperature, increasing volume (C) Decreasing\ + \ volume, decreasing temperature (D) Decreasing moles of gas, increasing volume" + target: 'Let''s think step by step. We refer to Wikipedia articles on medicine + for help. The relevant equation for this is the ideal gas law: PV=nRT. To increase + the pressure of the system (P), then either n (number of moles of the gas) or + T (temperature) have to increase. The answer is (A).' + - question: 'In a genetic test of a newborn, a rare genetic disorder is found that + has X-linked recessive transmission. Which of the following statements is likely + true regarding the pedigree of this disorder? + + (A) All descendants on the maternal side will have the disorder. (B) Females + will be approximately twice as affected as males in this family. (C) All daughters + of an affected male will be affected. (D) There will be equal distribution of + males and females affected.' + target: "Let's think step by step. We refer to Wikipedia articles on medicine\ + \ for help. Let\u2019s solve this step by step. Let's recall first that females\ + \ have two X chromosomes, while males have one X and one Y chromosome. This\ + \ is an important fact we need to know before answering this question. \nBecause\ + \ a male can only pass his only one X chromosome to a daughter, if he is affected\ + \ by this rare genetic disorder, then we know for sure that he will pass this\ + \ rare genetic disorder to all his future-born daughters. Therefore, \u201C\ + (C): All daughters of an affected male will be affected\u201D is a correct statement.\ + \ The answer is (C)." + - question: 'Glucose is transported into the muscle cell: + + (A) via protein transporters called GLUT4. (B) only in the presence of insulin. + (C) via hexokinase. (D) via monocarbylic acid transporters.' + target: 'Let''s think step by step. We refer to Wikipedia articles on medicine + for help. Glucose (also known as the blood sugar) is the main sugar found in + the human body. It is transported into the muscle cell via diffusion through + protein transporters called GLUT4. The answer is (A).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_college_medicine diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aa158a9f3c3e9e420608970a3ea91e744a042a6a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_college_physics.yaml @@ -0,0 +1,61 @@ +dataset_name: college_physics +description: The following are multiple choice questions (with answers) about college + physics. +fewshot_config: + sampler: first_n + samples: + - question: 'A refracting telescope consists of two converging lenses separated by + 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification + of the telescope is + + (A) 4 (B) 5 (C) 6 (D) 20' + target: Let's think step by step. In a refracting telescope, if both lenses are + converging, the focus of both lenses must be between the two lenses, and thus + the focal lengths of the two lenses must add up to their separation. Since the + focal length of one lens is 20 cm, the focal length of the other must be 80 + cm. The magnification is the ratio of these two focal lengths, or 4. The answer + is (A). + - question: 'The muon decays with a characteristic lifetime of about 10^-6 second into + an electron, a muon neutrino, and an electron antineutrino. The muon is forbidden + from decaying into an electron and just a single neutrino by the law of conservation + of + + (A) charge (B) mass (C) energy and momentum (D) lepton number' + target: Let's think step by step. Lepton number must be conserved, meaning the + total number of leptons minus the number of antileptons. If a muon decays into + an electron and a single neutrino, the total lepton number would go from one + to two, violating lepton number conservation. The answer is (D). + - question: 'One end of a Nichrome wire of length 2L and cross-sectional area A is + attached to an end of another Nichrome wire of length L and cross- sectional + area 2A. If the free end of the longer wire is at an electric potential of 8.0 + volts, and the free end of the shorter wire is at an electric potential of 1.0 + volt, the potential at the junction of the two wires is most nearly equal to + + (A) 2.4 V (B) 3.3 V (C) 4.5 V (D) 5.7 V' + target: Let's think step by step. This is a simple voltage divider problem, where + the longer wire has a resistance four times that of the shorter end. So the + voltage divider ratio is 1 / 5, meaning that the potential in the middle is + 1.0 V + (8.0 V - 1.0 V) * 1/5 = 2.4 V. The answer is (A). + - question: 'A refracting telescope consists of two converging lenses separated by + 100 cm. The eye-piece lens has a focal length of 20 cm. The angular magnification + of the telescope is + + (A) 4 (B) 5 (C) 6 (D) 20' + target: Let's think step by step. In a refracting telescope, if both lenses are + converging, the focus of both lenses must be between the two lenses, and thus + the focal lengths of the two lenses must add up to their separation. Since the + focal length of one lens is 20 cm, the focal length of the other must be 80 + cm. The magnification is the ratio of these two focal lengths, or 4. The answer + is (A). + - question: 'For which of the following thermodynamic processes is the increase in + the internal energy of an ideal gas equal to the heat added to the gas? + + (A) Constant temperature (B) Constant volume (C) Constant pressure (D) Adiabatic' + target: 'Let''s think step by step. Heat added to the gas can go into the gases + internal energy or work done against an external force. However, if the volume + of the gas container is constant, no work will be done (since work is pressure + times change in volume). So, at constant volume, all of the heat goes into the + internal energy. The answer is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_college_physics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a7c5cb2d5d1a2943b5b1ad6c17bf5ce0a324f5a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_computer_security.yaml @@ -0,0 +1,50 @@ +dataset_name: computer_security +description: The following are multiple choice questions (with answers) about computer + security. +fewshot_config: + sampler: first_n + samples: + - question: 'SHA-1 has a message digest of + + (A) 160 bits (B) 512 bits (C) 628 bits (D) 820 bits' + target: Let's think step by step. Since SHA-1 is a hash function which takes an + question and produces a 160-bit (20-byte) hash value, its message digest is 160 + bits. The answer is (A). + - question: "_____________ can modify data on your system \u2013 so that your system\ + \ doesn\u2019t run correctly or you can no longer access specific data, or it\ + \ may even ask for ransom in order to give your access.\n(A) IM \u2013 Trojans\ + \ (B) Backdoor Trojans (C) Trojan-Downloader (D) Ransom Trojan" + target: Let's think step by step. The system is asking for trojans, which are + for ransom, which means ransom trojan. The answer is (D). + - question: 'What is ethical hacking? + + (A) "Hacking" ethics so they justify unintended selfish behavior (B) Hacking + systems (e.g., during penetration testing) to expose vulnerabilities so they + can be fixed, rather than exploited (C) Hacking into systems run by those whose + ethics you disagree with (D) A slang term for rapid software development, e.g., + as part of hackathons' + target: Let's think step by step. Ethical hacking is a process of detecting vulnerabilities + in an application, system, or organization's infrastructure that an attacker + can use to exploit an individual or organization. They use this process to prevent + cyberattacks and security breaches by lawfully hacking into the systems and + looking for weak points. The answer is (B). + - question: 'The ____________ is anything which your search engine cannot search. + + (A) Haunted web (B) World Wide Web (C) Surface web (D) Deep Web' + target: "Let's think step by step. The search engine searches on the Surface Web,\ + \ which is the portion of the world wide web which is visible so (B,C) are wrong.\ + \ The Haunted Web doesn\u2019t correspond to an internet concept. The Deep Web\ + \ is the part of the World Wide Web which is not indexed. The answer is (D)." + - question: 'Exploitation of the Heartbleed bug permits + + (A) overwriting cryptographic keys in memory (B) a kind of code injection (C) + a read outside bounds of a buffer (D) a format string attack' + target: 'Let''s think step by step. The Heartbleed Bug is a serious vulnerability + in the popular OpenSSL cryptographic software library. Heartbleed resulted from + improper question validation (due to a missing bounds check) in the implementation + of the TLS heartbeat extension. The vulnerability was classified as a buffer + over-read, a situation where more data can be read than should be allowed. The + answer is (C).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_computer_security diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4757faf1d490e312ea70a3d4dbe291459366471 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_conceptual_physics.yaml @@ -0,0 +1,49 @@ +dataset_name: conceptual_physics +description: ' + + The following are multiple choice questions (with answers) about conceptual physics.' +fewshot_config: + sampler: first_n + samples: + - question: 'Colors in a soap bubble result from light + + (A) converted to a different frequency (B) deflection (C) interference (D) polarization' + target: Let's think step by step. In a soap bubble film, the light bounces between + the two soap-air interfaces many times, interfering with itself constructively + or destructively depending on the width of the film. This results in different + colors being visible. The answer is (C). + - question: 'Compared with the mass of a uranium atom undergoing fission, the combined + masses of the products after fission are + + (A) less (B) more (C) the same (D) zero' + target: Let's think step by step. Fission releases energy, which comes from the + rest mass of its initial nucleus. Thus the mass of the products is less than + the mass of the reactant uranium nucleus. The answer is (A). + - question: 'Things that are equivalent according to the equivalence principle are + + (A) space and time. (B) a traveling twin and a stay-at-home twin. (C) gravity + and acceleration. (D) mass and energy.' + target: "Let's think step by step. Einstein\u2019s famous equivalence principle\ + \ states that gravity and acceleration are equivalent. The answer is (C)." + - question: 'Which of these three elements has the most mass per nucleon? + + (A) Hydrogen (B) Iron (C) Uranium (D) Same in each' + target: Let's think step by step. Due to nuclear binding energy, the mass of an + atomic nucleus is less than the sum of individual masses of the free constituent + protons and neutrons; this is known as the mass defect. Hydrogen has no mass + defect because it has only a single nucleon, so it will have the most mass per + nucleon. The answer is (A). + - question: 'A model airplane flies slower when flying into the wind and faster with + wind at its back. When launched at right angles to the wind a cross wind its + groundspeed compared with flying in still air is + + (A) the same (B) greater (C) less (D) either greater or less depending on wind + speed' + target: "Let's think step by step. The plane\u2019s speed in the direction of\ + \ the wind is greater than it would be in the absence of wind, and its direction\ + \ orthogonal to the wind is the same as it would be in the absence of the wind.\ + \ The total speed, which is these two components added in quadrature, is thus\ + \ greater than the speed in still air. The answer is (B).\n\n" +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_conceptual_physics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4282345ce0f2e8bab97a80413fbd2b796a7fe3e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_econometrics.yaml @@ -0,0 +1,87 @@ +dataset_name: econometrics +description: The following are multiple choice questions (with answers) about econometrics. +fewshot_config: + sampler: first_n + samples: + - question: 'Suppose now that a researcher wishes to use information criteria to determine + the optimal lag length for a VAR. 500 observations are available for the bi-variate + VAR, and the values of the determinant of the variance-covariance matrix of + residuals are 0.0336, 0.0169, 0.0084, and 0.0062 for 1, 2, 3, and 4 lags respectively. + What is the optimal model order according to Akaike''s information criterion? + + (A) 1 lag (B) 2 lags (C) 3 lags (D) 4 lags' + target: "Let's think step by step. We refer to Wikipedia articles on econometrics\ + \ for help. Let\u2019s solve this problem step by step. First of all, let\u2019\ + s recall that for a given set of data, Akaike's information criterion (AIC)\ + \ allows us to measure how well a statistical model fits the data; it is an\ + \ estimator of prediction error. Here in this problem we will need to use the\ + \ formula ln(det(sigma_hat)) + (2 * k / T) to determine the values of Akaike\u2019\ + s criterion, where ln denotes the natural log function, det the determinant\ + \ function, k the total number of parameters in total (across both equations),\ + \ and T the number of observations (which, in this case, is equal to 500). For\ + \ 1 lag, the number of parameters in total is equal to 6; for 2 lags, it is\ + \ 10; for 3 lags, it is 14; and for 4 lags, it is 18. Now, let\u2019s calculate\ + \ the values of the criterion for each lag:\n(A) 1 lag: ln(0.0336) + (2 * 6\ + \ / 500) = ln(0.0336) + (12 / 500) = -3.369\n(B) 2 lags: ln(0.0169) + (2 * 10\ + \ / 500) = ln(0.0169) + (20 / 500) = -4.040\n(C) 3 lags: ln(0.0084) + (2 * 14\ + \ / 500) = ln(0.0084) + (28 / 500) =-4.724\n(D) 4 lags: ln(0.0062) + (2 * 18\ + \ / 500) = ln(0.0062) + (36 / 500) =-5.011\nBecause the optimal model order\ + \ according to AIC minimizes the information criterion, the answer should be\ + \ the one with the lowest value. In this case, (D) has the lowest value. The\ + \ answer is (C)." + - question: 'Consider the following AR(1) model with the disturbances having zero mean + and unit variance + + yt = 0.2 + 0.4 yt-1 + ut + + The (unconditional) mean of y will be given by + + (A) 0.2 (B) 0.4 (C) 0.5 (D) 0.33' + target: "Let's think step by step. We refer to Wikipedia articles on econometrics\ + \ for help. Let\u2019s solve this problem step by step. If we have a an AR(1)\ + \ model with the disturbances having zero mean and unit variance, then the unconditional\ + \ mean of y is equal to the following:\nunconditional mean of y = (the intercept\ + \ term) / (1 - autoregressive coefficient)\nWe know that the intercept term\ + \ is 0.2 and the autoregressive coefficient is 0.4; thus, we have:\nunconditional\ + \ mean of y = (0.2) / (1 - 0.4) = (0.2) / (0.6) = 2 / 6 = 1 / 3, which is approximately\ + \ 0.33. That means that the answer should be (D) 0.33. The answer is (D)." + - question: 'What would be then consequences for the OLS estimator if heteroscedasticity + is present in a regression model but ignored? + + (A) It will be biased (B) It will be inconsistent (C) It will be inefficient + (D) All of (a), (b) and (c) will be true.' + target: Let's think step by step. We refer to Wikipedia articles on econometrics + for help. Heteroscedasticity refers to the condition where the variance of the + error terms is not constant across multiple observations. If heteroscedasticity + is present in a regression model, then the coefficient estimates in the OLS + estimator will be not only unbiased and consistent but also inefficient. Because + (A) and (B) are incorrect choices and (C) is a correct choice, (D) cannot be + the right answer. Ultimately, (C) is the only true choice. The answer is (C). + - question: 'Suppose that a test statistic has associated with it a p-value of 0.08. + Which one of the following statements is true? + + (i) If the size of the test were exactly 8%, we would be indifferent between + rejecting and not rejecting the null hypothesis + + (ii) The null would be rejected if a 10% size of test were used + + (iii) The null would not be rejected if a 1% size of test were used + + (iv) The null would be rejected if a 5% size of test were used. + + (A) (ii) and (iv) only (B) (i) and (iii) only (C) (i), (ii), and (iii) only + (D) (i), (ii), (iii), and (iv).' + target: "Let's think step by step. We refer to Wikipedia articles on econometrics\ + \ for help. Let\u2019s reason about each of the options.\n(i) is a true statement.\n\ + (ii) is a true statement.\n(iii) is a true statement.\n(iv) is not a true statement.\ + \ Thus, (i), (ii), and (iii) are true. The answer is (C)." + - question: 'For a stationary autoregressive process, shocks will + + (A) Eventually die away (B) Persist indefinitely (C) Grow exponentially (D) + Never occur' + target: 'Let''s think step by step. We refer to Wikipedia articles on econometrics + for help. This is a formal logic problem about stationally process. For a stationary + autoregressive process, shocks will eventually die away. The answer is (A).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_econometrics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..305d2340c5ffa69761ef8dc2ab128849e571bbf8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_electrical_engineering.yaml @@ -0,0 +1,47 @@ +dataset_name: electrical_engineering +description: ' + + The following are multiple choice questions (with answers) about electrical engineering.' +fewshot_config: + sampler: first_n + samples: + - question: "A point pole has a strength of 4\u03C0 * 10^-4 weber. The force in newtons\ + \ on a point pole of 4\u03C0 * 1.5 * 10^-4 weber placed at a distance of 10\ + \ cm from it will be\n(A) 15 N. (B) 20 N. (C) 7.5 N. (D) 3.75 N." + target: "Let's think step by step. The force between two point poles is given\ + \ by m_1m_2/(mu_0 4 \\pi r^2), in analogy to Coulomb\u2019s law. Plugging in\ + \ the values given in the question, we calculate that the force is approximately\ + \ 15 N. The answer is (A)." + - question: 'The coil of a moving coil meter has 100 turns, is 40 mm long and 30 mm + wide. The control torque is 240*10-6 N-m on full scale. If magnetic flux density + is 1Wb/m2 range of meter is + + (A) 1 mA. (B) 2 mA. (C) 3 mA. (D) 4 mA.' + target: Let's think step by step. The torque on a coil in a uniform magnetic field + is given by BANI, where B is the magnetic flux density, A is the area of the + coil, N is the number of turns, and I is the current. So we have that I = (Torque)/(BAN), + or 240e-6/(1200e-6 * 100 * 1) = 2e-3. The answer is (B). + - question: 'In an SR latch built from NOR gates, which condition is not allowed + + (A) S=0, R=0 (B) S=0, R=1 (C) S=1, R=0 (D) S=1, R=1' + target: Let's think step by step. An SR latch is a set-reset latch; in the case + where S=1 and R=1, the circuit has no stable state; instead a race condition + will be produced within the circuit, so the device will be in an undefined state. + So S=1, R=1 is an illegal question. The answer is (D). + - question: 'Two long parallel conductors carry 100 A. If the conductors are separated + by 20 mm, the force per meter of length of each conductor will be + + (A) 100 N. (B) 0.1 N. (C) 1 N. (D) 0.01 N.' + target: Let's think step by step. The magnetic force-per-length between two current-carrying + conductors is given by \mu_0 I_1 I_2 / (2 \pi r), where $r$ is the separation + distance and I_1 and I_2 are the currents. Plugging in 100 A for I_1 and I_2, + and 20 mm for r, gives 0.1 N. The answer is (B). + - question: "In a 2 pole lap winding dc machine , the resistance of one conductor is\ + \ 2\u03A9 and total number of conductors is 100. Find the total resistance\n\ + (A) 200\u03A9 (B) 100\u03A9 (C) 50\u03A9 (D) 10\u03A9" + target: 'Let''s think step by step. In lap winding, effectively two resistors + are connected in parallel, so the actual resistance of each pair is 1 Ohm. Since + we have 50 pairs, we get a total resistance of 50 Ohms. The answer is (C).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_electrical_engineering diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fb9eef860424b49adf108566edadaf1500dbf5f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_elementary_mathematics.yaml @@ -0,0 +1,77 @@ +dataset_name: elementary_mathematics +description: The following are multiple choice questions (with answers) about elementary + mathematics. +fewshot_config: + sampler: first_n + samples: + - question: 'Olivia used the rule "Add 11" to create the number pattern shown below. + 10, 21, 32, 43, 54. Which statement about the number pattern is true? + + (A) The 10th number in the pattern will be an even number. + + (B) The number pattern will never have two even numbers next to each other. + + (C) The next two numbers in the pattern will be an even number then an odd number. + + (D) If the number pattern started with an odd number then the pattern would + have only odd numbers in it.' + target: Let's think step by step. Choice A is incorrect because every even-numbered + term in the pattern is odd, and 10 is an even number. Choice B is correct, because + adding an odd number (in this case 11) to an odd number produces an even number, + and adding an odd number to an even number produces an odd number. Thus the + terms in the pattern will alternate between odd and even, so there will never + be two even numbers next to each other. Choice C is incorrect because the last + term in the example is even (54), and we know that the terms will alternate + between even and odd. Choice D is incorrect because the terms in the pattern + will alternate between odd and even, regardless of the value of the first term. + The answer is (B). + - question: 'The population of the city where Michelle was born is 145,826. What is + the value of the 5 in the number 145,826? + + (A) 5 thousands + + (B) 5 hundreds + + (C) 5 tens + + (D) 5 ones' + target: 'Let''s think step by step. Choice A is correct, because there are three + digits following the 5, so + + the 5 is in the thousands place. Thus the other choices are incorrect. The answer + is (A).' + - question: "A store sells 107 different colors of paint. They have 25 cans of each\ + \ color in storage. The number of cans of paint the store has in storage can\ + \ be found using the expression below. 107 \xD7 25. How many cans of paint does\ + \ the store have in storage?\n(A) 749\n(B) 2,675\n(C) 2,945\n(D) 4,250" + target: Let's think step by step. We can calculate 107 x 25 = (100 x 25) + (7 + x 25) = 2500 + 175 = 2675. The answer is (B). + - question: 'A total of 30 players will play basketball at a park. There will be exactly + 5 players on each team. Which statement correctly explains how to find the number + of teams needed? + + (A) Add 5 to 30 to find 35 teams. + + (B) Divide 30 by 5 to find 6 teams. + + (C) Multiply 30 and 5 to find 150 teams. + + (D) Subtract 5 from 30 to find 25 teams.' + target: Let's think step by step. We want to find the number of teams. We know + that there are 5 players/team, and 30 players. Thus to get the number of teams + we divide players by players/team, so 30 players / 5 players/team = 6 teams. + The answer is (B). + - question: 'Which expression is equivalent to 5 x 9? + + (A) (5 x 4) x (6 x 5) + + (B) (5 x 5) + (5 x 4) + + (C) (5 x 5) + (5 x 9) + + (D) (5 x 9) x (6 x 9)' + target: 'Let''s think step by step. We know that 9 = (5 + 4), so 5 x 9 = 5 x (5 + + 4) = (5 x 5) + (5 x 4). The answer is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_elementary_mathematics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fbc73c3d24f9cef06a41bbcfddea55aec1b424a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_formal_logic.yaml @@ -0,0 +1,70 @@ +dataset_name: formal_logic +description: The following are multiple choice questions (with answers) about formal + logic. +fewshot_config: + sampler: first_n + samples: + - question: "Which of the given formulas of PL is the best symbolization of the following\ + \ sentence?\nTurtles live long lives and are happy creatures, unless they are\ + \ injured.\n(A) (L \u2022 H) \u2261 I (B) (L \u2022 H) \u2228 I (C) L \u2022\ + \ (H \u2228 I) (D) L \u2022 (H \u2283 R)." + target: "Let's think step by step. We refer to Wikipedia articles on formal logic\ + \ for help. Let\u2019s solve this step by step. Let \u201CL\u201D denote \u201C\ + living long\u201D, H \u201Cbeing happy\u201D, and \u201CI\u201D \u201Cbeing\ + \ injured\u201D. Now, consider each choice:\n(A) means (living long AND being\ + \ happy) is equivalent to (being injured). \n(B) means (living long AND being\ + \ happy) OR (being injured). \n(C) means (living long) AND (being happy OR being\ + \ injured). \n(D) means (living long) AND (being happy implies being R), but\ + \ what R denotes is not clear.\nObviously, (B) is the best symbolization of\ + \ the original sentence. The answer is (B)." + - question: 'Select the best translation into predicate logic.George borrows Hector''s + lawnmower. (g: George; h: Hector; l: Hector''s lawnmower; Bxyx: x borrows y + from z). + + (A) Blgh (B) Bhlg (C) Bglh (D) Bghl' + target: "Let's think step by step. We refer to Wikipedia articles on formal logic\ + \ for help. Let\u2019s solve this step by step. We are told that \u201CBxyx\u201D\ + \ means \u201Cx borrows y from z\u201D. We can rewrite \u201CGeorge borrows\ + \ Hector's lawnmower\u201D as \u201CGeorge borrows a lawnmower from Hector\u201D\ + , which can then be translated into predicate logic as \u201CBglh\u201D. The\ + \ answer \u201CBglh\u201D appears in (C); therefore, (C) must be the correct\ + \ answer. The answer is (C)." + - question: "\nSelect the best English interpretation of the given arguments in predicate\ + \ logic.\nDm\n(\u2200x)(Wx \u2283 ~Dx). \n(\u2200x)Wx \u2228 Ag\t/ (\u2203x)Ax\n\ + (A) Marina is a dancer. Some weaklings are not dancers. Either everything is\ + \ a weakling or Georgia plays volleyball. So something plays volleyball. (B)\ + \ Marina is a dancer. No weakling is a dancer. Everything is either a weakling\ + \ or plays volleyball. So something plays volleyball. (C) Marina is a dancer.\ + \ Some weaklings are not dancers. Everything is either a weakling or plays volleyball.\ + \ So something plays volleyball. (D) Marina is a dancer. No weakling is a dancer.\ + \ Either everything is a weakling or Georgia plays volleyball. So something\ + \ plays volleyball." + target: "Let's think step by step. We refer to Wikipedia articles on formal logic\ + \ for help. Let\u2019s solve this step by step. Let \u201CD\u201D denote \u201C\ + being a dancer\u201D, \u201Cm\u201D denote \u201CMaria\u201D, \u201Cg\u201D\ + \ denote \u201CGeorgia\u201D, \u201CW\u201D denote \u201Cweakling\u201D, \u201C\ + A\u201D denote \u201Cplaying volleyball\u201D. Then, we have the following:\n\ + 1. Dm \u2192 Maria is a dance.\n2. (\u2200x)(Wx \u2283 ~Dx). \u2192 For all\ + \ x, if x is a weakling, then x is not a dancer. In other words, no weakling\ + \ is a dancer.\n3. (\u2200x)Wx \u2228 Ag\t/ (\u2203x)Ax \u2192 For all x, x\ + \ is a weakling or Georgia plays volleyball. So there exists an x that plays\ + \ volleyball. \nOptions (A) and (C) do claim that some weaklings are not dancers,\ + \ but the second argument strongly states that no weakling is a dancer. Thus,\ + \ we can eliminate them. Option (B) omits the important detail about Georgia\ + \ playing volleyball. Option (D) has all the details presented in the arguments\ + \ and is the best English interpretation of the arguments. The answer is (D)." + - question: "Select the best translation into predicate logic: No people drive on Mars.\n\ + (A) ~Pd (B) (\u2200x)(Px \u2228 ~Dx) (C) (\u2200x)(Px \u2283 ~Dx) (D) ~Dp" + target: "Let's think step by step. We refer to Wikipedia articles on formal logic\ + \ for help. Let\u2019s solve this step by step. Let \u201CP\u201D denote \u201C\ + being on Mars\u201D and \u201CD\u201D denote \u201Cdriving on Mars\u201D. Then\ + \ let\u2019s consider each option:\nOption (A): ~Pd \u2192 d is not on Mars.\n\ + Option (B): (\u2200x)(Px \u2228 ~Dx) \u2192 For all x, x is on Mars and x do\ + \ not drive on Mars.\nOption (C): (\u2200x)(Px \u2283 ~Dx) \u2192 For all x,\ + \ x is on Mars implies that x do not drive on Mars.\nOption (D): ~Dp: \u2192\ + \ p do not drive on Mars.\nOf all these options, Option (C) appears to be the\ + \ best and most meaningful interpretation of the argument \u201CNo people drive\ + \ on Mars.\u201D The answer is (C).\n\n" +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_formal_logic diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..739c00106536c8df64213d55b83c053db1314124 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_global_facts.yaml @@ -0,0 +1,49 @@ +dataset_name: global_facts +description: The following are multiple choice questions (with answers) about global + facts. +fewshot_config: + sampler: first_n + samples: + - question: "As of 2017, how many of the world\u2019s 1-year-old children today have\ + \ been vaccinated against some disease? *\n(A) 80% (B) 60% (C) 40% (D) 20%" + target: Let's think step by step. We refer to Wikipedia articles on global facts + for help. According to data published by the World Health Organization, the + nummber of 1-year-old children vaccinated in 2017 exceeds 80%. The answer is + (A). + - question: 'As of 2019, about what percentage of Americans agree that the state is + run for the benefit of all the people? + + (A) 31% (B) 46% (C) 61% (D) 76%' + target: Let's think step by step. We refer to Wikipedia articles on global facts + for help. In 2019, about 46% percentage of Americans agree that the state is + run for the benefit of all the people. The answer is (B). + - question: 'As of 2019, about what percentage of Russians say it is very important + to have free media in our country without government/state censorship? + + (A) 38% (B) 53% (C) 68% (D) 83%' + target: Let's think step by step. We refer to Wikipedia articles on global facts + for help. As of 2019, about 38% of Russians say it is very important to have + free media in our country. The answer is (A). + - question: 'As of 2015, since 1990 forests have ____ in Europe and have ____ in Africa + and the Americas. + + (A) increased, increased (B) increased, decreased (C) decreased, increased (D) + decreased, decreased' + target: Let's think step by step. We refer to Wikipedia articles on global facts + for help. As of 2015, since 1990 forests have increased in Europe and have decreased + in Africa and the Americas. The answer is (B). + - question: 'Which of the following pairs of statements are both true (as of 2019)? + + (A) People tend to be optimistic about their own future and the future of their + nation or the world. (B) People tend to be optimistic about their own future + but pessimistic about the future of their nation or the world. (C) People tend + to be pessimistic about their own future but optimistic about the future of + their nation or the world. (D) People tend to be pessimistic about their own + future and the future of their nation or the world.' + target: 'Let''s think step by step. We refer to Wikipedia articles on global facts + for help. As of 2019, most people tend to be optimistic about their own future + but pessimistic about the future of their nation or the world. The answer is + (B).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_global_facts diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0dfb19f924761c6dd56cc6b3b9ada38b5bf473e0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_biology.yaml @@ -0,0 +1,69 @@ +dataset_name: high_school_biology +description: The following are multiple choice questions (with answers) about high + school biology. +fewshot_config: + sampler: first_n + samples: + - question: "In animal cells, which of the following represents the most likely pathway\ + \ that a secretory protein takes as it is synthesized in a cell?\n(A) Plasma\ + \ membrane\u2013Golgi apparatus\u2013ribosome\u2013secretory vesicle\u2013rough\ + \ ER (B) Ribosome\u2013Golgi apparatus\u2013rough ER\u2013secretory vesicle\u2013\ + plasma membrane (C) Plasma membrane\u2013Golgi apparatus\u2013ribosome\u2013\ + secretory vesicle\u2013rough ER (D) Ribosome\u2013rough ER\u2013Golgi apparatus\u2013\ + secretory vesicle\u2013plasma membrane" + target: Let's think step by step. Protein synthesis starts at the ribosome, so + we can eliminate (A) and (C). The ribosome is often in the endoplasmic reticulum + and moves from there to the Golgi apparatus, where it is modified and packaged + into a vesicle. The vesicle then floats to the plasma membrane and is secreted. + The answer is (D). + - question: "A mutation in a bacterial enzyme changed a previously polar amino acid\ + \ into a nonpolar amino acid. This amino acid was located at a site distant\ + \ from the enzyme\u2019s active site. How might this mutation alter the enzyme\u2019\ + s substrate specificity?\n(A) By changing the enzyme\u2019s pH optimum (B) By\ + \ changing the enzyme\u2019s location in the cell (C) By changing the shape\ + \ of the protein (D) An amino acid change away from the active site cannot alter\ + \ the enzyme\u2019s substrate specificity." + target: Let's think step by step. A change in an amino acid leads to a change + in the primary structure of the protein. A change in the primary structure may + lead to a change in the secondary and the tertiary structure of the protein. + A change in the tertiary structure means a change in the shape of the protein, + so (C) has to be correct. Since the change does not affect the active site of + the enzyme, we do not expect the activity of the enzyme to be affected. The + answer is (C). + - question: 'Which of the following is not a way to form recombinant DNA? + + (A) Translation (B) Conjugation (C) Specialized transduction (D) Transformation' + target: 'Let''s think step by step. The introduction of foreign DNA or RNA into + bacteria or eukaryotic cells is a common technique in molecular biology and + scientific research. There are multiple ways foreign DNA can be introduced into + cells including transformation, transduction, conjugation, and transfection. + In contrast, (A) is not a way to form DNA: during translation the ribosomes + synthesize proteins from RNA. The answer is (A).' + - question: 'Homologous structures are often cited as evidence for the process of natural + selection. All of the following are examples of homologous structures EXCEPT + + (A) the wings of a bird and the wings of a bat (B) the flippers of a whale and + the arms of a man (C) the pectoral fins of a porpoise and the flippers of a + seal (D) the forelegs of an insect and the forelimbs of a dog' + target: "Let's think step by step. \u200B\u200BHomologous structures are similar\ + \ physical features in organisms that share a common ancestor \u200B\u200Bbut\ + \ different functions. Comparisons (B) and (C) are clearly homologous because\ + \ they share a common ancestor and the structures serve different purposes.\ + \ Bat wings and birg wings are also homologous, while they are both wings, the\ + \ forelimbs serve different purposes. Insects and dogs are very far ancestors\ + \ since one is vertebrate while the other is invertebrate and the forelimbs\ + \ serve the same purpose, so they are not homologous. The answer is (D)." + - question: 'Which of the following is not known to be involved in the control of cell + division? + + (A) Cyclins (B) Protein kinases (C) Checkpoints (D) Fibroblast cells' + target: 'Let''s think step by step. Normal cells move through the cell cycle in + a regulated way. At the checkpoint stage, they use information about their own + internal state and cues from the environment around them to decide whether to + proceed with cell division. Cues like these act by changing the activity of + core cell cycle regulators inside the cell. The most common regulators are cyclins + and cyclin-dependent kinases. Fibroblast cells do not play any role in cell + division. The answer is (D).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_biology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..948c375824506cc188ed821c4512725bfdc92d06 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_chemistry.yaml @@ -0,0 +1,66 @@ +dataset_name: high_school_chemistry +description: The following are multiple choice questions (with answers) about high + school chemistry. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following is considered an acid anhydride? + + (A) HCl (B) H2SO3 (C) SO2 (D) Al(NO3)3' + target: Let's think step by step. An acid anhydride is a compound that is derived + by removing water from an acid. The chemical formula for water is H2O, which + means that we need to determine which of these options, when combined with H2O, + forms an acid. SO2, or Sulfur dioxide, when combined with H2O, makes H2SO4, + or sulfuric acid. The answer is (C). + - question: 'Which of the following is expected to be a polar molecule? + + (A) PCl4F (B) BF3 (C) CO2 (D) Si(CH3)4' + target: Let's think step by step. A polar molecule is one that has a slightly + positive charge on one end of the molecule and a slightly negative charge on + the other end. Boron trifluoride (BF3) has Boron as the center atom and three + fluorine atoms attached to it; it is trigonal planar and symmetric, so it is + nonpolar. Carbon Dioxide (CO2) has Carbon as the central atom with double bonds + to two Oxygen atoms - this is also symmetrical and therefore nonpolar. The same + is the case for tetramethyl silane (SI(CH3)4), which is a Silicon atom surrounded + by four methyl groups. The structure of PCL4F is that Phosphorus is the central + atom, attached to four chlorines and one fluorine atom. This is asymmetrical, + and therefore has a net dipole and is expected to be a polar molecule. The answer + is (A). + - question: 'From the solubility rules, which of the following is true? + + (A) All chlorides, bromides, and iodides are soluble (B) All sulfates are soluble + (C) All hydroxides are soluble (D) All ammonium-containing compounds are soluble' + target: Let's think step by step. The chlorides, bromides, and iodides of lead, + silver, and mercury are not soluble in water. This rules out (A). The sulfates + of lead, barium, and calcium are not soluble in water, which rules out (B). + The hydroxides of any metal besides sodium, potassium, ammonium, calcium, and + barium are insoluble. This rules out (C). Typically ammonium ions indicate a + soluble ionic substance. The answer is (D). + - question: 'A new compound is synthesized and found to be a monoprotic acid with a + molar mass of 248 g/mol. When 0.0050 mol of this acid are dissolved in 0.500 + L of water, the pH is measured as 3.89. What is the pKa of this acid? + + (A) 3.89 (B) 7.78 (C) 5.78 (D) 2.33' + target: "Let's think step by step. Recall that $[A] = [H^{+}]$. Here, this is\ + \ equal to $$10^{-3.89}$. Then we have $K_{a} = $\nrac{[H^{+}][A^{-}]}{[HA]}\ + \ = \nrac{10^{-3.89} \\cdot 10^{-3.89}}{10^{-2}}. The resulting exponent is\ + \ $-3.89 + (-3.89) - (-2) = 5.78$, therefore $K_a = 10^{-5.78}$. The $pK_a$\ + \ is the negative log of $K_a$, which is equal to $5.78$. The answer is (C)." + - question: 'A solution contains 2.00 mole of acetic acid, CH3COOH, and 1.00 mole of + calcium acetate, Ca(CH3COO)2. The solution is able to resist the addition of + a small amount of strong acid or strong base with only minor changes in the + pH of the solution. Larger quantities of strong acid or strong base can cause + a significant change in pH. How many moles of nitric acid, HNO3, may be added + before the pH begins to change significantly? + + (A) 0.500 mole (B) 1.00 mole (C) 2.00 mole (D) 3.00 mole' + target: "Let's think step by step. We would like to compute the buffer capacity\ + \ of this solution. First we write the equation for the ionization of the weak\ + \ acid, in this case of acetic acid. $CH_{3}COOH (aq) + H_{2}O \nightarrow H_{3}O^{+}\ + \ + CH3COO^{-}$. The conjugate base is therefore the acetate ion. The added\ + \ strong acid, Nitric acid, will react with the conjugate base. Therefore the\ + \ maximum amount of acid that can be added will be equal to the amount of acetate\ + \ ion, or 2 moles. The answer is (C).\n\n" +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_chemistry diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f21030ae880ec0d6d42f0e2618c312b55b82549 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_computer_science.yaml @@ -0,0 +1,84 @@ +dataset_name: high_school_computer_science +description: The following are multiple choice questions (with answers) about high + school computer science. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following is an example of the use of a device on the Internet + of Things (IoT) ? + + (A) A car alerts a driver that it is about to hit an object. (B) A hiker uses + a G P S watch to keep track of her position. (C) A refrigerator orders milk + from an online delivery service when the milk in the refrigerator is almost + gone. (D) A runner uses a watch with optical sensors to monitor his heart rate.' + target: Let's think step by step. The term Internet of Things (IoT) refers to + common devices which are connected to the internet, enabling new functionality. + Choice A is incorrect because it does not describe an internet connected device. + In choice B, the watch is only described as having GPS functionality but no + internet connectivity. Choice C describes a common device (a refrigerator) which + has internet connectivity enabling new functionality (online ordering). Choice + D does not mention internet connectivity for the watch, only optical sensors. + The answer is (C). + - question: 'Many Web browsers allow users to open anonymous windows. During a browsing + session in an anonymous window, the browser does not record a browsing history + or a list of downloaded files. When the anonymous window is exited, cookies + created during the session are deleted. Which of the following statements about + browsing sessions in an anonymous window is true? + + (A) The activities of a user browsing in an anonymous window will not be visible + to people who monitor the user''s network, such as the system administrator. + (B) Items placed in a Web store''s shopping cart for future purchase during + the anonymous browsing session will not be saved on the user''s computer. (C) + A user will not be able to log in to e-mail or social media accounts during + the anonymous browsing session. (D) A user browsing in an anonymous window will + be protected from viruses launched from any web sites visited or files downloaded.' + target: "Let's think step by step. Choice A is incorrect as it only describes\ + \ network traffic, which an anonymous browser does not change. Choice B is correct\ + \ as it correctly describes how an anonymous browser will prevent saving data\ + \ on the user\u2019s computer after the session is ended. Choice C is incorrect\ + \ because an anonymous browser will not prevent logging in to email or social\ + \ media accounts. Choice D is incorrect because an anonymous browser in itself\ + \ performs no virus protection. The answer is (B)." + - question: "In the program below, the initial value of X is 5 and the initial value\ + \ of Y is 10.\nIF (X < 0){\n DISPLAY (\"Foxtrot\")\n} ELSE {\n IF (X > Y){\n\ + \ DISPLAY (\"Hotel\")\n } ELSE {\n IF (Y > 0){\n DISPLAY (\"November\")\n\ + \ } ELSE {\n DISPLAY (\"Yankee\")\n }\n }\n}\nWhat is displayed as a result\ + \ of running the program?\n(A) Foxtrot (B) Hotel (C) November (D) Yankee" + target: Let's think step by step. Because X has the value 5, the first conditional + IF (X < 0) is false, so we move to the first ELSE clause. Because X is 5 and + Y is 10, the second conditional IF (X > Y) is false, so we move to the following + ELSE clause. Since Y is 10, the conditional IF (Y > 0) is true, so the command + DISPLAY ("November") is executed. The answer is (C). + - question: 'What is the output of "abc"[::-1] in Python 3? + + (A) Error (B) abc (C) cba (D) c' + target: Let's think step by step. We know that the slicing operator [::-1] takes + all of the elements in the string in reverse order, so we reverse the order + of the string "abc", resulting in "cba". The answer is (C). + - question: "A list of numbers has n elements, indexed from 1 to n. The following algorithm\ + \ is intended to display the number of elements in the list that have a value\ + \ greater than 100. The algorithm uses the variables count and position. Steps\ + \ 3 and 4 are missing.\n Step 1: Set count to 0 and position to 1.\n Step 2:\ + \ If the value of the element at index position is greater than 100, increase\ + \ the value of count by 1.\n Step 3: (missing step)\n Step 4: (missing step)\n\ + \ Step 5: Display the value of count.\nWhich of the following could be used\ + \ to replace steps 3 and 4 so that the algorithm works as intended?\n(A) Step\ + \ 3: Increase the value of position by 1.\n Step 4: Repeat steps 2 and 3 until\ + \ the value of count is greater than 100.\n(B) Step 3: Increase the value of\ + \ position by 1.\n Step 4: Repeat steps 2 and 3 until the value of position\ + \ is greater than n.\n(C) Step 3: Repeat step 2 until the value of count is\ + \ greater than 100.\n Step 4: Increase the value of position by 1.\n(D) Step\ + \ 3: Repeat step 2 until the value of position is greater than n.\n Step 4:\ + \ Increase the value of count by 1." + target: 'Let''s think step by step. Choice A is incorrect, because its Step 4 + has an incorrect termination condition, stopping when count is greater than + 100. We need to stop after inspecting all elements in the list. Choice B is + correct because it correctly increments both count and position, and correctly + repeats these steps and terminates when all elements in the list have been inspected. + Choice C is incorrect because it incorrectly increments the variable count until + its value is greater than 100, regardless of the elements in the list. Choice + D is incorrect because its step 3 does not increment the value of position, + so it will repeat forever. The answer is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_computer_science diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4707a1857f44fad0ef67313b5f1901b5b6c869b6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_european_history.yaml @@ -0,0 +1,199 @@ +dataset_name: high_school_european_history +description: The following are multiple choice questions (with answers) about high + school european history. +fewshot_config: + sampler: first_n + samples: + - question: 'This question refers to the following information. + + Albeit the king''s Majesty justly and rightfully is and ought to be the supreme + head of the Church of England, and so is recognized by the clergy of this realm + in their convocations, yet nevertheless, for corroboration and confirmation + thereof, and for increase of virtue in Christ''s religion within this realm + of England, and to repress and extirpate all errors, heresies, and other enormities + and abuses heretofore used in the same, be it enacted, by authority of this + present Parliament, that the king, our sovereign lord, his heirs and successors, + kings of this realm, shall be taken, accepted, and reputed the only supreme + head in earth of the Church of England, called Anglicans Ecclesia; and shall + have and enjoy, annexed and united to the imperial crown of this realm, as well + the title and style thereof, as all honors, dignities, preeminences, jurisdictions, + privileges, authorities, immunities, profits, and commodities to the said dignity + of the supreme head of the same Church belonging and appertaining; and that + our said sovereign lord, his heirs and successors, kings of this realm, shall + have full power and authority from time to time to visit, repress, redress, + record, order, correct, restrain, and amend all such errors, heresies, abuses, + offenses, contempts, and enormities, whatsoever they be, which by any manner + of spiritual authority or jurisdiction ought or may lawfully be reformed, repressed, + ordered, redressed, corrected, restrained, or amended, most to the pleasure + of Almighty God, the increase of virtue in Christ''s religion, and for the conservation + of the peace, unity, and tranquility of this realm; any usage, foreign land, + foreign authority, prescription, or any other thing or things to the contrary + hereof notwithstanding. + + English Parliament, Act of Supremacy, 1534 + + From the passage, one may infer that the English Parliament wished to argue + that the Act of Supremacy would + + (A) give the English king a new position of authority (B) give the position + of head of the Church of England to Henry VIII alone and exclude his heirs (C) + establish Calvinism as the one true theology in England (D) end various forms + of corruption plaguing the Church in England' + target: Let's think step by step. We refer to Wikipedia articles on european history + for help. The Act of Supremacy states that it grants authority to the king "to + repress and extirpate all errors, heresies, and other enormities and abuses", + referring to the corruption in the Church of England. The answer is (D). + - question: "This question refers to the following information.\nRead the following\ + \ excerpt.\nThe revolutionary seed had penetrated into every country and spread\ + \ more or less. It was greatly developed under the r\xE9gime of the military\ + \ despotism of Bonaparte. His conquests displaced a number of laws, institutions,\ + \ and customs; broke through bonds sacred among all nations, strong enough to\ + \ resist time itself; which is more than can be said of certain benefits conferred\ + \ by these innovators.\nThe monarchs will fulfil the duties imposed upon them\ + \ by Him who, by entrusting them with power, has charged them to watch over\ + \ the maintenance of justice, and the rights of all, to avoid the paths of error,\ + \ and tread firmly in the way of truth. Placed beyond the passions which agitate\ + \ society, it is in days of trial chiefly that they are called upon to despoil\ + \ realities of their false appearances, and to show themselves as they are,\ + \ fathers invested with the authority belonging by right to the heads of families,\ + \ to prove that, in days of mourning, they know how to be just, wise, and therefore\ + \ strong, and that they will not abandon the people whom they ought to govern\ + \ to be the sport of factions, to error and its consequences, which must involve\ + \ the loss of society.\nUnion between the monarchs is the basis of the policy\ + \ which must now be followed to save society from total ruin. . . .\nLet them\ + \ not confound concessions made to parties with the good they ought to do for\ + \ their people, in modifying, according to their recognized needs, such branches\ + \ of the administration as require it.\nLet them be just, but strong; beneficent,\ + \ but strict.\nLet them maintain religious principles in all their purity, and\ + \ not allow the faith to be attacked and morality interpreted according to the\ + \ social contract or the visions of foolish sectarians.\nLet them suppress Secret\ + \ Societies; that gangrene of society.\n\u2014Klemens von Metternich, Political\ + \ Confession of Faith, 1820\nWhich of the following was the greatest cause of\ + \ the fears expressed by Metternich in the document above?\n(A) The ideas of\ + \ personal liberty and nationalism conceived during the Enlightenment resulted\ + \ in radical revolutions that could spread throughout Europe. (B) The conquest\ + \ of Europe by Napoleon led to the creation of new factions and shifted the\ + \ European balance of power. (C) The power of monarchs had grown to the point\ + \ where it needed to be checked by other powers within each nation or domination\ + \ of civilians would occur. (D) The rising and falling economic cycle of the\ + \ newly emerging capitalist economy could lead to civilian unrest that must\ + \ be suppressed." + target: Let's think step by step. We refer to Wikipedia articles on european history + for help. The fears of revolution in early 19th century Europe expressed by + Klemens von Metternich, a conservative Austrian statesman, were a direct result + of the age of Enlightenment, a period of European history where the absolute + power of the monarchy was challenged with ideas of individual liberty and nationalism, + leading to the French revolution and its effects all over Europe. The answer + is (A). + - question: 'This question refers to the following information. + + The excerpts below are from the Navigation Acts of 1651. + + [A]fter the first day of December, one thousand six hundred fifty and one, and + from thence forwards, no goods or commodities whatsoever of the growth, production + or manufacture of Asia, Africa or America, or of any part thereof; or of any + islands belonging to them, or which are described or laid down in the usual + maps or cards of those places, as well of the English plantations as others, + shall be imported or brought into this Commonwealth of England, or into Ireland, + or any other lands, islands, plantations, or territories to this Commonwealth + belonging, or in their possession, in any other ship or ships, vessel or vessels + whatsoever, but only in such as do truly and without fraud belong only to the + people of this Commonwealth, or the plantations thereof, as the proprietors + or right owners thereof; and whereof the master and mariners are also of the + people of this Commonwealth, under the penalty of the forfeiture and loss of + all the goods that shall be imported contrary to this act, , , , + + [N]o goods or commodities of the growth, production, or manufacture of Europe, + or of any part thereof, shall after the first day of December, one thousand + six hundred fifty and one, be imported or brought into this Commonwealth of + England, or any other lands or territories to this Commonwealth belonging, or + in their possession, in any ship or ships, vessel or vessels whatsoever, but + in such as do truly and without fraud belong only to the people of this Commonwealth, + and in no other, except only such foreign ships and vessels as do truly and + properly belong to the people of that country or place, of which the said goods + are the growth, production or manufacture. + + Which of the following best describes the outcome of the Navigation Acts of + 1651? + + (A) They served as a catalyst for the growth of English shipping and overseas + trade, but did little to limit the prospects of the Dutch in the seventeenth + century. (B) They brought about almost immediate hardships for the Dutch economy + as their dominance of overseas trade quickly ended. (C) They were rescinded + during the restoration of the Stuarts as they sought normal diplomatic relations + with the Dutch so not as to need Parliament''s financial support for war. (D) + They led to nearly a century of recurrent war between England and the Netherlands, + which would not end until after American independence.' + target: Let's think step by step. We refer to Wikipedia articles on european history + for help. The Navigation Acts of 1651 helped English shipping by restricting + the ability of ships from other European countries, especially the Dutch, to + transport goods from colonies in Asia and Africa into England. The answer is + (A). + - question: "This question refers to the following information.\nIn Russia there was\ + \ nothing going on well, and [Souvarine] was in despair over the news he had\ + \ received. His old companions were all turning to the politicians; the famous\ + \ Nihilists who made Europe tremble-sons of village priests, of the lower middle\ + \ class, of tradesmen-could not rise above the idea of national liberation,\ + \ and seemed to believe that the world would be delivered-when they had killed\ + \ their despot&\u2026\n\"Foolery! They'll never get out of it with their foolery.\"\ + \nThen, lowering his voice still more, in a few bitter words he described his\ + \ old dream of fraternity. He had renounced his rank and his fortune; he had\ + \ gone among workmen, only in the hope of seeing at last the foundation of a\ + \ new society of labour in common. All the sous in his pockets had long gone\ + \ to the urchins of the settlement; he had been as tender as a brother with\ + \ the colliers, smiling at their suspicion, winning them over by his quiet workmanlike\ + \ ways and his dislike of chattering. But decidedly the fusion had not taken\ + \ place.\nHis voice changed, his eyes grew bright, he fixed them on \xE9tienne,\ + \ directly addressing him:\n\"Now, do you understand that? These hatworkers\ + \ at Marseilles who have won the great lottery prize of a hundred thousand francs\ + \ have gone off at once and invested it, declaring that they are going to live\ + \ without doing anything! Yes, that is your idea, all of you French workmen;\ + \ you want to unearth a treasure in order to devour it alone afterwards in some\ + \ lazy, selfish corner. You may cry out as much as you like against the rich,\ + \ you haven't got courage enough to give back to the poor the money that luck\ + \ brings you. You will never be worthy of happiness as long as you own anything,\ + \ and your hatred of the bourgeois proceeds solely from an angry desire to be\ + \ bourgeois yourselves in their place.\"\n\xE9mile Zola, French writer, Germinal,\ + \ 1885\nThe passage displays the direct concern for the welfare of the working\ + \ classes that was typically a part of which movement?\n(A) Capitalist (B) Scientific\ + \ (C) Communist (D) Existentialist" + target: Let's think step by step. We refer to Wikipedia articles on european history + for help. The modern Communist movement aims to establish a classless society + based on communal ownership and distribution of property and means of production, + thereby especially benefiting the working classes. The answer is (C). + - question: "This question refers to the following information.\nThe following excerpt\ + \ is from a pamphlet.\nYou will do me the justice to remember, that I have always\ + \ strenuously supported the Right of every man to his own opinion, however different\ + \ that opinion might be to mine. He who denies to another this right, makes\ + \ a slave of himself to his present opinion, because he precludes himself the\ + \ right of changing it.\nThe most formidable weapon against errors of every\ + \ kind is Reason. I have never used any other, and I trust I never shall.\n\ + The circumstance that has now taken place in France of the total abolition of\ + \ the whole national order of priesthood, and of everything appertaining to\ + \ compulsive systems of religion, and compulsive articles of faith, has not\ + \ only precipitated my intention, but rendered a work of this kind exceedingly\ + \ necessary, lest in the general wreck of superstition, of false systems of\ + \ government, and false theology, we lose sight of morality, of humanity, and\ + \ of the theology that is true.\nI believe in one God, and no more; and I hope\ + \ for happiness beyond this life.\nI believe in the equality of man; and I believe\ + \ that religious duties consist in doing justice, loving mercy, and endeavoring\ + \ to make our fellow-creatures happy.\nI do not believe in the creed professed\ + \ by the Jewish church, by the Roman church, by the Greek church, by the Turkish\ + \ church, by the Protestant church, nor by any church that I know of. My own\ + \ mind is my own church.\nAll national institutions of churches, whether Jewish,\ + \ Christian or Turkish, appear to me no other than human inventions, set up\ + \ to terrify and enslave mankind, and monopolize power and profit.\nI do not\ + \ mean by this declaration to condemn those who believe otherwise; they have\ + \ the same right to their belief as I have to mine.\n\u2014Thomas Paine, The\ + \ Age of Reason, 1794\u20131795\nWhich of the following Enlightenment philosophes\ + \ designed a system of checks and balances for government to avoid abuses of\ + \ power?\n(A) Jean Jacques Rousseau (B) Baron Montesquieu (C) Mary Wollstonecraft\ + \ (D) Adam Smith" + target: 'Let''s think step by step. We refer to Wikipedia articles on european + history for help. Baron Montesquieu was a 18th centrury French philsopher who + wrote extensively against the monoplization of power and advocated for a system + of checks and balances in government to prevent the rise of despotism. The answer + is (B).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_european_history diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96f4b365af04a3dc5c754def2899a5291ada1072 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_geography.yaml @@ -0,0 +1,53 @@ +dataset_name: high_school_geography +description: The following are multiple choice questions (with answers) about high + school geography. +fewshot_config: + sampler: first_n + samples: + - question: 'Which one of the following items is an example of nonmaterial culture? + + (A) Dove soap (B) Dove candy bar (C) Dove symbol (D) A dove (bird).' + target: Let's think step by step. We refer to Wikipedia articles on geography + for help. Nonmaterial culture consists of cultural ideas, beliefs or symbols + that are not physical objects. The answer is (C). + - question: 'During the third stage of the demographic transition model, which of the + following is true? + + (A) Birth rates increase and population growth rate is less rapid. (B) Birth + rates decline and population growth rate is less rapid. (C) Birth rates increase + and population growth rate increases. (D) Birth rates decrease and population + growth rate increases.' + target: Let's think step by step. We refer to Wikipedia articles on geography + for help. The demographic transition model models the five different stages + of population growth as a country goes through economic development, where the + third stage refers to a period of declining birth rates and lower population + growth. The answer is (B). + - question: 'The practice of hiring a foreign third-party service provider to run an + operation is called + + (A) outsourcing. (B) offshoring. (C) maquiladoras. (D) locational interdependence.' + target: Let's think step by step. We refer to Wikipedia articles on geography + for help. "Offshoring" literally means to move or base some of the activities + or processes of a company to a foreign country. The answer is (B). + - question: 'Which of the following statements is NOT accurate regarding the services + provided by local governments in the United States? + + (A) Duplication of efforts occurs often. (B) Social problems of the central + city spill over into the surrounding residential suburbs. (C) Inefficiency in + providing services occurs often. (D) One neighborhood''s efforts to reduce pollution + are always supported by neighboring communities.' + target: Let's think step by step. We refer to Wikipedia articles on geography + for help. There may be economic, social or political reasons for two neighboring + communities and their local governments not agreeing to pollution reduction + efforts initiated by one of them. The answer is (D). + - question: 'The rate of natural increase of a population is found by subtracting the + + (A) crude death rate from the crude birth date. (B) crude birth rate from the + crude death rate. (C) doubling time from the crude birth rate. (D) fertility + rate from the crude death rate.' + target: 'Let''s think step by step. We refer to Wikipedia articles on geography + for help. The difference between number of births and deaths gives the population + increase at any given time. The answer is (A).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_geography diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c11772183c0266288f95f4d273ebbdf32c0dba1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,61 @@ +dataset_name: high_school_government_and_politics +description: The following are multiple choice questions (with answers) about high + school government and politics. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following best states an argument made by James Madison in + The Federalist number 10? + + (A) Honest politicians can prevent factions from developing. (B) Factions are + more likely to occur in large republics than in small ones. (C) The negative + effects of factionalism can be reduced by a republican government. (D) Free + elections are the people''s best defense against factionalism.' + target: Let's think step by step. We refer to Wikipedia articles on government + and politics for help. In the Federalist number 10, James Madison advocated + for a representative republican form of government to guard against factionalism. + The answer is (C). + - question: 'The term "budget deficit" refers to the + + (A) annual increase in federal spending on the military (B) amount of interest + on the national debt (C) difference between the initial budget proposals made + by the president and Congress (D) amount the government spends in excess of + its revenues' + target: Let's think step by step. We refer to Wikipedia articles on government + and politics for help. When the goverment spends more than it earns, their difference + is the budget deficit. The answer is (D). + - question: 'Which of the following statements about cabinet departments is FALSE? + + (A) They are established by the legislative branch. (B) Their members often + don''t have much influence over presidential decisions. (C) They cannot all + be run by leaders who belong to the same political party the president does. + (D) Not every federal agency is a cabinet department.' + target: Let's think step by step. We refer to Wikipedia articles on government + and politics for help. There is no law stipulating that some cabinet department + leaders have to belong to a political party different from that of the president. + The answer is (C). + - question: 'Which of the following cases established the precedent that a defendant + must be informed of the right to remain silent, the right to a lawyer, and protection + from self-incrimination? + + (A) Weeks v. United States (B) Betts v. Brady (C) Mapp v. Ohio (D) Miranda v. + Arizona' + target: Let's think step by step. We refer to Wikipedia articles on government + and politics for help. In the landmark Miranda v. Arizona in 1966, the US Supreme + Court, based on the Fifth and Sixth Amendment of the US Constitution, guaranteed + a defendant's right to an attorney and protection from self-incrimination. The + answer is (D). + - question: 'Uncertainty over the limits to presidential power is caused primarily + by the fact that + + (A) the constitutional definition of those powers is broad and unspecific (B) + most people agree that the Constitution places too many limits on presidential + power (C) the Supreme Court consistently refuses to rule on cases concerning + presidential powers (D) constitutional amendments have greatly increased presidential + powers' + target: 'Let''s think step by step. We refer to Wikipedia articles on government + and politics for help. The US Constitution is not very specific about the powers + of the president, leading to uncertainty over its limits. The answer is (A).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_government_and_politics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5700d5df946539608bbf10555e0f06ca07672b09 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,53 @@ +dataset_name: high_school_macroeconomics +description: The following are multiple choice questions (with answers) about high + school macroeconomics. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following policies best describes supply-side fiscal policy? + + (A) An increase in the money supply (B) Increased government spending (C) Lower + taxes on research and development of new technology (D) Higher taxes on household + income' + target: Let's think step by step. We refer to Wikipedia articles on macroeconomics + for help. Supply-side fiscal policy stimulates the economy by encouraging more + production of goods and services through reduction in taxes and deregulation. + The answer is (C). + - question: 'The short-run Phillips curve indicates a + + (A) direct relation between unemployment and inflation (B) direct relation between + price and quantity demanded (C) inverse relation between price and quantity + demanded (D) inverse relation between unemployment and inflation' + target: Let's think step by step. We refer to Wikipedia articles on macroeconomics + for help. The short-run Phillips curve shows that whenever unemployment decreases + below a natural level, the inflation starts increasing, and vice-versa. The + answer is (D). + - question: 'Holding all else equal which of the following monetary policies would + be used to boost U.S. exports? + + (A) Increasing the discount rate (B) Increasing the reserve ratio (C) Buying + government securities (D) Lowering tariffs' + target: Let's think step by step. We refer to Wikipedia articles on macroeconomics + for help. Buying government securities leads to reduction in demand for US dollars + from foreign buyers, thereby making it cheaper and hence making US exports more + attractive. The answer is (C). + - question: 'A federal deficit occurs when + + (A) exports exceed imports. (B) imports exceed exports. (C) federal tax collections + exceed spending. (D) federal spending exceeds federal tax revenues.' + target: Let's think step by step. We refer to Wikipedia articles on macroeconomics + for help. A federal deficit occurs when federal spending exceeds federal income + which is primarily from tax revenues. The answer is (D). + - question: 'Which of the following is not included in the U.S. GDP? + + (A) The U.S. military opens a new base in a foreign country with 1000 U.S. personnel. + (B) Japanese consumers buy thousands of CDs produced in the United States. (C) + An American pop singer performs a sold-out concert in Paris. (D) A French theatrical + production tours dozens of American cities.' + target: 'Let''s think step by step. We refer to Wikipedia articles on macroeconomics + for help. The economic transactions related to the performance of the American + pop-singer in Paris happens entirely outside the U.S. and hence is not included + in the GDP numbers. The answer is (C).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_macroeconomics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3e05795561a11dca35dd342bc88b4794584809c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_mathematics.yaml @@ -0,0 +1,51 @@ +dataset_name: high_school_mathematics +description: The following are multiple choice questions (with answers) about high + school mathematics. +fewshot_config: + sampler: first_n + samples: + - question: 'Simplify and write the result with a rational denominator: $$\sqrt{\sqrt[3]{\sqrt{\frac{1}{729}}}}$$ + + (A) \frac{3\sqrt{3}}{3} (B) \frac{1}{3} (C) \sqrt{3} (D) \frac{\sqrt{3}}{3}' + target: Let's think step by step. Factoring $729=3^6$ and combining the roots + $\frac{1}{2}\frac{1}{3}\frac{1}{2}=\frac{1}{12}$, we get that $\sqrt{\sqrt[3]{\sqrt{\frac{1}{729}}}}=\left(\frac{1}{3^6}\right)^{\frac{1}{12}}=\frac{1}{3^{\frac{1}{2}}}=\frac{3}{\sqrt{3}}$ + The answer is (D). + - question: 'Five thousand dollars compounded annually at an $x\%$ interest rate takes + six years to double. At the same interest rate, how many years will it take + $\$300$ to grow to $\$9600$? + + (A) 12 (B) 1 (C) 30 (D) 5' + target: Let's think step by step. To go from $\$300$ to $\$9600$, the value must + go up by a factor of $9600/300=32=2^5$. Since at this interest rate it takes + six years for it to double, it will take $5*6=30$ years to grow to $\$9600$. + The answer is (C). + - question: "Ten students take a biology test and receive the following scores: 45,\ + \ 55, 50, 70, 65, 80, 40, 90, 70, 85. What is the mean of the students\u2019\ + \ test scores?\n(A) 55 (B) 60 (C) 62 (D) 65" + target: Let's think step by step. There are 10 students and the sum of their scores + is $45 + 55 + 50 + 70 + 65 + 80 + 40 + 90 + 70 + 85 = 650$, the mean is $650/10=65$. + The answer is (D). + - question: 'The variable $x$ varies directly as the square of $y$, and $y$ varies + directly as the cube of $z$. If $x$ equals $-16$ when $z$ equals 2, what is + the value of $x$ when $z$ equals $\frac{1}{2}$? + + (A) -1 (B) 16 (C) -\frac{1}{256} (D) \frac{1}{16}' + target: Let's think step by step. We know that $x \propto y^2$ and $y \propto + z^3$, so $x = k z^6$ for some constant $k$. Plugging in for $x=-16$ and $z=2$, + the constant value is $k=\frac{x}{z^6}=\frac{-16}{64}=-\frac{1}{4}$. So, when + $z=\frac{1}{2}$, the value of $x$ is $x=kz^6=-\frac{1}{4}\frac{1}{2^6}=-\frac{1}{256}$. + The answer is (C). + - question: 'Joe was in charge of lights for a dance. The red light blinks every two + seconds, the yellow light every three seconds, and the blue light every five + seconds. If we include the very beginning and very end of the dance, how many + times during a seven minute dance will all the lights come on at the same time? + (Assume that all three lights blink simultaneously at the very beginning of + the dance.) + + (A) 3 (B) 15 (C) 6 (D) 5' + target: 'Let''s think step by step. The least common multiple of 2, 3 and 5 is + 30, so during a 7 minute dance, all the three lights will come on at the same + time $2*7+1=15$ times. The answer is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_mathematics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ead78be898187a7c08c292eaeebfe6b067ef4413 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,56 @@ +dataset_name: high_school_microeconomics +description: The following are multiple choice questions (with answers) about high + school microeconomics. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following is necessarily a characteristic of oligopoly? + + (A) Free entry into and exit from the market (B) A few large producers (C) One + producer of a good with no close substitutes (D) A homogenous product' + target: Let's think step by step. We refer to Wikipedia articles on microeconomics + for help. An oligopoly is when a market is dominated by just one or a few number + of sellers or producers. To get oligopoly, the market should have high barriers + to new entry, and the product has differentiation. The answer is (B). + - question: 'If the government subsidizes producers in a perfectly competitive market, + then + + (A) the demand for the product will increase (B) the demand for the product + will decrease (C) the consumer surplus will increase (D) the consumer surplus + will decrease' + target: Let's think step by step. We refer to Wikipedia articles on microeconomics + for help. (A) and (B) are wrong because the demand curve does not change at + all. If the government subsidizes producers, the supply will increase, and thus + the consumer surplus also increases. The answer is (C). + - question: 'Which of the following is true of a price floor? + + (A) The price floor shifts the demand curve to the left. (B) An effective floor + creates a shortage of the good. (C) The price floor shifts the supply curve + of the good to the right. (D) To be an effective floor, it must be set above + the equilibrium price.' + target: Let's think step by step. We refer to Wikipedia articles on microeconomics + for help. Price floor does not shift the demand or shift curve. An effective + price floor should be set above the equilibrium price, otherwise the market + bears and the floor does not have effective effect. The answer is (D). + - question: 'The concentration ratio for a monopoly is + + (A) 0 (B) 5 (C) 10 (D) 100' + target: Let's think step by step. We refer to Wikipedia articles on microeconomics + for help. The concentration ratio is calculated as the sum of market share of + a specific number of largest companies. Monopoly means one company or entity + controls the entire market, therefore, the concentration ratio is 100 percent. + The answer is (D). + - question: 'In a competitive labor market for housepainters, which of the following + would increase the demand for housepainters? + + (A) An effective minimum wage imposed on this labor market. (B) An increase + in the price of gallons of paint. (C) An increase in the construction of new + houses. (D) An increase in the price of mechanical painters so long as the output + effect exceeds the substitution effect.' + target: 'Let''s think step by step. We refer to Wikipedia articles on microeconomics + for help. An increase in the construction of new houses means an increase demand + of in-house painting, thus increases the demand for housepainters. The answer + is (C).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_microeconomics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d2166b8595529213797cf12efd67f90ca4e9e61 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_physics.yaml @@ -0,0 +1,50 @@ +dataset_name: high_school_physics +description: The following are multiple choice questions (with answers) about high + school physics. +fewshot_config: + sampler: first_n + samples: + - question: 'A microwave oven is connected to an outlet, 120 V, and draws a current + of 2 amps. At what rate is energy being used by the microwave oven? + + (A) 10 W (B) 30 W (C) 60 W (D) 240 W' + target: Let's think step by step. Rate of energy usage is known as power; in an + dissipative electrical circuit, power is given by voltage times current. So + in our case, the power is 120 V times 2 amps, or 240 W. The answer is (D). + - question: "A point charge, Q = +1 mC, is fixed at the origin. How much work is required\ + \ to move a charge, Q = +8 \xB5C, from the point (0, 4 meters) to the point\ + \ (3 meters, 0)?\n(A) 3.5 J (B) 6.0 J (C) 22.5 J (D) 40 J" + target: "Let's think step by step. To calculate the work required to move a charge\ + \ from one location to another in a fixed electric field, it is enough to calculate\ + \ the potential difference between the two locations. Here, the potential only\ + \ depends on the distance between the charges; it\u2019s $k q_1 q_2 / r$, where\ + \ $k$ is Coulomb\u2019s constant. Plugging in values $q_1 = $ 1 mC, $q_2 = 8\ + \ \\mu$ C, gives the answer as 5.992 J, which rounds to 6 J. The answer is (B)." + - question: 'Which of the following conditions will ensure that angular momentum is + conserved? I. Conservation of linear momentum II. Zero net external force III. + Zero net external torque + + (A) I and II only (B) I and III only (C) II and III only (D) III only' + target: Let's think step by step. Torque is defined as the change in angular momentum; + if there is zero external torque, angular momentum is conserved. The answer + is (D). + - question: "A photocell of work function \u03D5 = 2eV is connected to a resistor in\ + \ series. Light of frequency f = 1 \xD7 10^15 Hz hits a metal plate of the photocell.\ + \ If the power of the light is P = 100 W, what is the current through the resistor?\n\ + (A) 2:00 AM (B) 6:00 AM (C) 12:00 AM (D) 24 A" + target: Let's think step by step. The only answer above which has units of current + is D, 24 A. The answer is (D). + - question: "A pipe full of air is closed at one end. A standing wave is produced in\ + \ the pipe, causing the pipe to sound a note. Which of the following is a correct\ + \ statement about the wave\u2019s properties at the closed end of the pipe?\n\ + (A) The pressure is at a node, but the particle displacement is at an antinode.\ + \ (B) The pressure is at an antinode, but the particle displacement is at a\ + \ node. (C) The pressure and the particle displacement are both at nodes. (D)\ + \ The pressure and the particle displacement are both at antinodes." + target: 'Let''s think step by step. At the closed end of the pipe, the particles + cannot have any net displacement because the pipe closure stops them. So the + particle displacement is at a node. This closure also causes the pressure to + be maximal, i.e. an antinode. The answer is (B).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_physics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..553fe18de24f14765d14d0fb2dbdfaf72449af32 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_psychology.yaml @@ -0,0 +1,64 @@ +dataset_name: high_school_psychology +description: The following are multiple choice questions (with answers) about high + school psychology. +fewshot_config: + sampler: first_n + samples: + - question: 'Pascale is interested in the processing strategies children use to learn + new information. Pascale would best be classified as what type of psychologist? + + (A) sociocultural (B) clinical (C) cognitive (D) behaviorist' + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. Sociocultural psychologist focuses on the effect of societal factors + on people. Clinical psychologist focuses on people with mental issues. Cognitive + psychologist focuses on how people think and learn, including the processing + strategies. Behaviorist focuses more on the environment and experience effect + on people. The answer is (C). + - question: 'According to Caplan''s model of consultee-centered case consultation, + the consultant is primarily interested in + + (A) identifying the causes and solutions of the client''s presenting problems + (B) identifying and eliminating the causes of the consultee''s difficulties + in handling a problem (C) establishing a hierarchy of authority to enable effective + decision making (D) presenting a single, well-defined and unambiguous course + of action for the consultant to overcome skills deficits' + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. Caplan defines two type of consultation. Client-centered case consultation + aims to handle client's problems, while consultee-centered case consultation + aims to identify the reason of client's difficulty to solve problems. The answer + is (B). + - question: 'According to the Individuals with Disabilities Education Improvement Act, + which of the following must an educational agency do before it changes the educational + placement of a student with a disability? + + (A) Give the child a trial period in the new environment (B) Notify the parents + in writing (C) Obtain school board approval (D) Obtain parental consent' + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. When the decision to change the educational placement of a student + with a disability is made, the educational agency must notify the parents in + writing on that date. The answer is (B). + - question: 'While swimming in the ocean, Ivan is frightened by a dark shadow in the + water even before he has the chance to identify what the shadow is. The synaptic + connections taking place during this incident of fright are best described by + which of the following? + + (A) Messages are sent from the thalamus directly to the amygdala. (B) Messages + are sent from the thalamus to the "what" and "where" pathways. (C) Messages + are sent from the parasympathetic nervous system to the cerebral cortex. (D) + Messages are sent from the frontal lobes to the pituitary gland.' + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. Our neural system has a mechanism that can respond immediate emotional + signal before going to the thought center. In the Ivan's case, messages travel + directly from thalamus to amygdala. The answer is (A). + - question: 'Ani believes that her attitudes and behavior play a central role in what + happens to her. Such a belief is likely to be associated with + + (A) a strong superego. (B) low self-esteem. (C) low self-efficacy. (D) an internal + locus of control.' + target: 'Let''s think step by step. We refer to Wikipedia articles on psychology + for help. People with an external locus of control believes fate and luck play + an important role in their lives, while people with an internal locus of control + believes they control their lives. The answer is (D).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_psychology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66ed702a6b9b57dc69719a13923e04ebc8218ea9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_statistics.yaml @@ -0,0 +1,81 @@ +dataset_name: high_school_statistics +description: The following are multiple choice questions (with answers) about high + school statistics. +fewshot_config: + sampler: first_n + samples: + - question: 'A new smartwatch is manufactured in one part of a factory, then secured + for shipping in another, independent part of the factory. The weight of the + smartwatch has a mean of 62 grams and a standard deviation of 1.0 grams. The + weight of the packaging (box, user''s guide, bubble wrap, etc.) has a mean of + 456 grams and a standard deviation of 6 grams. Together, the distribution of + the weight of the smartwatch and its packaging would have the following mean + and standard deviation: + + (A) Mean 518 grams; standard deviation 7.0 grams (B) Mean 518 grams; standard + deviation 3.5 grams (C) Mean 518 grams; standard deviation 6.1 grams (D) Mean + 394 grams; standard deviation 6.1 grams' + target: Let's think step by step. Since the weight of the watch and the weight + of the packaging are independent random variables, the mean and variance of + their sum is equal to the sum of their individual means and variances. So the + mean is 62 + 456 = 518 grams, and the variances is 1.0^2 + 6.0^2 = 37, leading + to a standard deviation of 6.1 grams. The answer is (C). + - question: 'After a frost warning was issued, the owner of a large orange grove asked + his workers to spray all his trees with water. The water was supposed to freeze + and form a protective covering of ice around the orange blossom. Nevertheless, + the owner suspected that some trees suffered considerable damage due to the + frost. To estimate the proportion of trees that suffered more than 50 percent + damage due to the frost, he took a random sample of 100 trees from his grove. + What is the response variable in this experiment? + + (A) The proportion of trees that suffered more than 50 percent damage due to + frost. (B) The number of trees affected by the frost. (C) The number of trees + sampled from the grove. (D) For each sampled tree, whether it suffered more + than 50 percent damage or at most 50 percent damage.' + target: Let's think step by step. In this experiment, the response variable is + what is measured. For each tree, what is measured is whether or not it suffered + more than 50 percent damage due to the frost. The answer is (D). + - question: 'Suppose X and Y are random variables with E(X) = 37, var(X) = 5, E(Y) + = 62, and var(Y) = 12. What are the expected value and variance of the random + variable X + Y? + + (A) E(X + Y) = 99, var(X + Y) = 8.5 (B) E(X + Y) = 99, var(X + Y) = 13 (C) E(X + + Y) = 99, var(X + Y) = 17 (D) There is insufficient information to answer this + question.' + target: Let's think step by step. While means of sums of random variables add + (regardless of whether the variables are independent) in order to determine + the variance of a sum of random variables, we need to know not just their individual + variances but the covariance of the two variables, which is not given in this + problem. The answer is (D). + - question: 'Which of the following sets has the smallest standard deviation? Which + has the largest? + + I: {1,2,3} + + II: {-10,10} + + III: {100} + + (A) I, II (B) II, III (C) III, I (D) III, II' + target: Let's think step by step. The variance of distribution I is the expected + squared deviation from its mean (which is 2), so the variance is 2/3 . The variance + of distribution II is 10^2 (because both elements are 10 away from the mean + of zero). The variance of distribution III is 0, since it has a single entry. + So distribution III has the smallest standard deviation and distribution II + has the largest. The answer is (D). + - question: 'Which of the following is a correct statement about correlation? + + (A) If the slope of the regression line is exactly 1, then the correlation is + exactly 1. (B) If the correlation is 0, then the slope of the regression line + is undefined. (C) Switching which variable is called x and which is called y + changes the sign of the correlation. (D) The correlation r is equal to the slope + of the regression line when z-scores for the y-variable are plotted against + z-scores for the x-variable.' + target: 'Let''s think step by step. Statement A is false because the slope of + the regression line being exactly 1 can occur even when the two variables are + not perfectly correlated. Statement B is false because uncorrelated variables + regression lines can have slope zero. Statement C is false because correlation + is symmetric in the two random variables. The answer is (D).' +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_statistics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8cea5109f6570086dce3cf1815dc50f1889d80ad --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_us_history.yaml @@ -0,0 +1,156 @@ +dataset_name: high_school_us_history +description: The following are multiple choice questions (with answers) about high + school us history. +fewshot_config: + sampler: first_n + samples: + - question: "This question refers to the following information.\nI come not to urge\ + \ personal claims, nor to seek individual benefits; I appear as the advocate\ + \ of those who cannot plead their own cause; I come as the friend of those who\ + \ are deserted, oppressed, and desolate. In the Providence of God, I am the\ + \ voice of the maniac whose piercing cries from the dreary dungeons of your\ + \ jails penetrate not your Halls of Legislation. I am the Hope of the poor crazed\ + \ beings who pine in the cells, and stalls, and cages, and waste rooms of your\ + \ poor-houses. I am the Revelation of hundreds of wailing, suffering creatures,\ + \ hidden in your private dwellings, and in pens and cabins\u2014shut out, cut\ + \ off from all healing influences, from all mind-restoring cares.\u2026 Could\ + \ their melancholy histories be spread before you as revealed to my grieved\ + \ spirit during the last three months, how promptly, how earnestly would you\ + \ search out the most approved means of relief; how trifling, how insignificant,\ + \ by comparison, would appear the sacrifices you are asked to make; how would\ + \ a few dimes and dollars, gathered from each citizen, diminish in value as\ + \ a possession, compared with the certain benefits and vast good to be secured\ + \ for the suffering insane...by the consecration and application of a sufficient\ + \ fund to the construction of a suitable hospital.\u2026\n\u2014Dorothea Dix,\ + \ Memorial Soliciting a State Hospital for the Protection and Cure of the Insane,\n\ + Submitted to the General Assembly of North Carolina, November 1848\nDorothea\ + \ Dix can best be compared to whom?\n(A) Abigail Adams (B) Clara Barton (C)\ + \ Shirley Temple (D) Hillary Clinton" + target: Let's think step by step. We refer to Wikipedia articles on us history + for help. Both Dorothea Dix and Clara barton are American nurses. The answer + is (B). + - question: "This question refers to the following information.\n\"As our late Conduct\ + \ at the Conestoga Manor and Lancaster have occasioned much Speculation & a\ + \ great diversity of Sentiments in this and neighboring Governments; some vindicating\ + \ & others condemning it; some charitably alleviating the Crime, & others maliciously\ + \ painting it in the most odious & detestable Colours, we think it our duty\ + \ to lay before the Publick, the whole Matter as it appeared, & still appears,\ + \ to us. . . .\n\"If these things are not sufficient to prove an unjustifiable\ + \ Attachment in the Quakers to the Indians Savages, a fixed Resolution to befriend\ + \ them & an utter insensibility to human Distresses, let us consider a few more\ + \ recent Facts. When we found the last Summer that we were likely to get no\ + \ Assistance from the Government, some Volunteers went out at our own Expense,\ + \ determined to drive our Enemies from our Borders; & when we came near to the\ + \ great Island, we understood that a Number of their Warriors had gone out against\ + \ our Frontiers. Upon this we returned and came up with them and fought with\ + \ them at the Munfey Hill where we lost some of our Men & killed some of their\ + \ Warriors & thereby saved our Frontiers from this Story in another Expedition.\ + \ But no sooner had we destroyed their Provisions on the great Island, & ruined\ + \ their trade with the good People at Bethlehem, but these very Indians, who\ + \ were justly suspected of having murdered our Friends in Northampton County,\ + \ were by the Influence of some Quakers taken under the Protection of the Government\ + \ to screen them from the Resentments of the Friends and Relations of the Murdered,\ + \ & to support them thro the Winter.\"\n\u2014\"Apology of the Paxton Boys\"\ + \ (pamphlet), 1764 (Note: \"apology\" in this context should be read as an explanation,\ + \ not an admission of guilt or regret.\nThe sentiments expressed in the explanation\ + \ above reflect which of the ongoing tensions during the colonial period of\ + \ American history?\n(A) Tensions between British policies and the aspirations\ + \ of North American colonists. (B) Tensions between American Indians allied\ + \ with the French and those allied with the British. (C) Tensions between freed\ + \ African Americans and white planters. (D) Tensions between backcountry settlers\ + \ and elites within colonial America." + target: Let's think step by step. We refer to Wikipedia articles on us history + for help. After the French and Indian War, the Scotch-Irish settlers attacked + American Indians. After the attacks on the Conestoga, about 250 Paxton Boys + present their grievances to the Pennsylvania legislature. As mentioned in the + information, the Paxton Boys cited resentiment at local elites. The answer is + (D). + - question: "This question refers to the following information.\nOur leaders talk about\ + \ stopping aggression from the north, but this was a struggle among groups of\ + \ Vietnamese until we intervened. We seem bent upon saving the Vietnamese from\ + \ Ho Chi Minh even if we have to kill them and demolish their country to do\ + \ it. As the native people survey bombed-out villages, women and children burned\ + \ by napalm, rice crops destroyed and cities overrun with our military personnel,\ + \ they are doubtless saying secretly of the Vietcong guerillas and of the American\ + \ forces, \"A plague on both your houses.\" \u2026 Stop the bombing, north and\ + \ south, end search and destroy offensive sweeps, and confine our military action\ + \ to holding operations on the ground. Bombing the north has failed to halt\ + \ or seriously check the flow of troops to the south and may, in fact, have\ + \ prompted a much greater war effort by Hanoi.\n\u2014Senator George McGovern,\ + \ \"The Lessons of Vietnam,\" April 25, 1967\nWhich of the following opinions\ + \ from the 1960s most directly reflects the perspective of George McGovern's\ + \ speech?\n(A) Americans must maximize their technological edge in Vietnam.\ + \ (B) American bombing in Vietnam is step by step leading to progress in the\ + \ war. (C) American bombing in Vietnam is a failure. (D) America must not give\ + \ in to defeatism about the war in Vietnam." + target: Let's think step by step. We refer to Wikipedia articles on us history + for help. "Stop the bombing" and "Bombing the north has failed to halt or seriously + check the flow of troops to the south" indicate that the perspective of George + McGovern's speech is that Amerian bombing in Vietnam is a failure. The answer + is (C). + - question: "This question refers to the following information.\n\"In the new Code\ + \ of Laws which I suppose it will be necessary for you to make I desire you\ + \ would Remember the Ladies, and be more generous and favorable to them than\ + \ your ancestors. Do not put such unlimited power into the hands of the Husbands.\ + \ Remember all Men would be tyrants if they could. If particular care and attention\ + \ is not paid to the Ladies we are determined to foment a Rebellion, and will\ + \ not hold ourselves bound by any Laws in which we have no voice, or Representation.\"\ + \nAbigail Adams, in a letter to John Adams, 1776\n\"Special legislation for\ + \ woman has placed us in a most anomalous position. Women invested with the\ + \ rights of citizens in one section\u2014voters, jurors, office-holders\u2014\ + crossing an imaginary line, are subjects in the next. In some States, a married\ + \ woman may hold property and transact business in her own name; in others,\ + \ her earnings belong to her husband. In some States, a woman may testify against\ + \ her husband, sue and be sued in the courts; in others, she has no redress\ + \ in case of damage to person, property, or character. In case of divorce on\ + \ account of adultery in the husband, the innocent wife is held to possess no\ + \ right to children or property, unless by special decree of the court. But\ + \ in no State of the Union has the wife the right to her own person, or to any\ + \ part of the joint earnings of the co-partnership during the life of her husband.\ + \ In some States women may enter the law schools and practice in the courts;\ + \ in others they are forbidden. In some universities girls enjoy equal educational\ + \ advantages with boys, while many of the proudest institutions in the land\ + \ deny them admittance, though the sons of China, Japan and Africa are welcomed\ + \ there. But the privileges already granted in the several States are by no\ + \ means secure.\"\nSusan B. Anthony, \"Declaration of Rights for Women,\" July\ + \ 4, 1876\nThe sentiments expressed in the second excerpt by Susan B. Anthony\ + \ are most likely in support of\n(A) the Equal Rights Amendment (B) universal\ + \ suffrage (C) states' rights (D) prohibition" + target: Let's think step by step. We refer to Wikipedia articles on us history + for help. The above information mentioned that women are in an anomalous position + in terms of legislation. Women's earnings do not belong to themselves, or they + cannot testify against her husbands. Susan believes women should have equal + legal rights as men. The answer is (B). + - question: 'This question refers to the following information. + + "Society in every state is a blessing, but government even in its best state + is but a necessary evil; in its worst state an intolerable one; for when we + suffer, or are exposed to the same miseries by a government, which we might + expect in a country without government, our calamity is heightened by reflecting + that we furnish the means by which we suffer. Government, like dress, is the + badge of lost innocence; the palaces of kings are built on the ruins of the + bowers of paradise. For were the impulses of conscience clear, uniform, and + irresistibly obeyed, man would need no other lawgiver; but that not being the + case, he finds it necessary to surrender up a part of his property to furnish + means for the protection of the rest; and this he is induced to do by the same + prudence which in every other case advises him out of two evils to choose the + least. Wherefore, security being the true design and end of government, it unanswerably + follows that whatever form thereof appears most likely to ensure it to us, with + the least expense and greatest benefit, is preferable to all others." + + Thomas Paine, Common Sense, 1776 + + Which of the following "miseries" alluded to above were most condemned by Anti-Federalists + of the post-Revolutionary era? + + (A) Organized response to Bacon''s Rebellion (B) Federal response to Shays''s + Rebellion (C) Federal response to the Whiskey Rebellion (D) Federal response + to Pontiac''s Rebellion' + target: 'Let''s think step by step. We refer to Wikipedia articles on us history + for help. Anti-Federalists do not believe centralized government power, and + suspect Washington''s military response to Whiskey Rebellion. Bacon''s Rebellion + and Pontiac''s Rebellion happen before the Revolution and they can be ruled + out. The answer is (C).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_us_history diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2fbdaf05c137270f4ff4207e7c6ce81c2a34d30c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_high_school_world_history.yaml @@ -0,0 +1,100 @@ +dataset_name: high_school_world_history +description: The following are multiple choice questions (with answers) about high + school world history. +fewshot_config: + sampler: first_n + samples: + - question: "This question refers to the following information.\n\"At least one of\ + \ the [world's] societies would have to somehow enormously increase its productivity\ + \ [in order to achieve global hegemony]. That quantum jump would have to be\ + \ made before the various scientific, technological, agricultural, and industrial\ + \ revolutions on which our post-quantum-leap world rests. It could only be accomplished\ + \ by exploiting the ecosystems, mineral resources, and human assets of whole\ + \ continents outside the lands of the society making the jump. Western Europe\ + \ did just that by means of its brutality and guns and, more important, by geographical\ + \ and ecological luck.\"\nCopyright \xA9 2015 Cambridge University Press.\n\ + Alfred Crosby, historian, Ecological Imperialism, 2004\nThe \"quantum jump\"\ + \ mentioned in the passage most directly contributed to which of the following\ + \ developments in the period 1450\u20131750 C.E.?\n(A) A breakdown in trade\ + \ routes through the collapse of the established state structure (B) An increase\ + \ in the population of the world through more plentiful supplies of food (C)\ + \ The spread of Chinese and Indian belief systems across the world (D) An increase\ + \ in social unrest" + target: Let's think step by step. We refer to Wikipedia articles on world history + for help. The "quantum jump" mentioned in the passage refers to the conquest + of the New World and the Columbian Exchange. Choice (A) and (C) did not happen + in history. Choice (C) refers to the human assets. The answer is (B). + - question: "This question refers to the following information.\n\"The struggle against\ + \ neo-colonialism is not aimed at excluding the capital of the developed world\ + \ from operating in less developed countries. It is aimed at preventing the\ + \ financial power of the developed countries being used in such a way as to\ + \ impoverish the less developed.\nNon-alignment, as practiced by Ghana and many\ + \ other countries, is based on co-operation with all States whether they be\ + \ capitalist, socialist or have a mixed economy. Such a policy, therefore, involves\ + \ foreign investment from capitalist countries, but it must be invested in accordance\ + \ with a national plan drawn up by the government of the non-aligned State with\ + \ its own interests in mind. The issue is not what return the foreign investor\ + \ receives on his investments\u2026The question is one of power. A State in\ + \ the grip of neo-colonialism is not master of its own destiny.\"\nKwame Nkrumah,\ + \ Neo-Colonialism, 1965\nWhich of the following provides the best context for\ + \ Nkrumah's writings?\n(A) The Industrial Revolution (B) Decolonization (C)\ + \ Regional Free Trade Associations (D) Autarky" + target: Let's think step by step. We refer to Wikipedia articles on world history + for help. The passage expresses a point that the successful fight against neo-colonialism + were in danger and the newly independent nations like Ghana may be re-colonized + via financial power of the developed countries. The answer is (B). + - question: "This question refers to the following information.\n\"Indeed, as both\ + \ the fatwas of distinguished [scholars] who base their opinion on reason and\ + \ tradition alike and the consensus of the Sunni community agree that the ancient\ + \ obligation of extirpation, extermination, and expulsion of evil innovation\ + \ must be the aim of our exalted aspiration, for \"Religious zeal is a victory\ + \ for the Faith of God the Beneficent\"; then, in accordance with the words\ + \ of the Prophet (Peace upon him!) \"Whosoever introduces evil innovation into\ + \ our order must be expelled\" and \"Whosoever does aught against our order\ + \ must be expelled,\" action has become necessary and exigent\u2026\"\nLetter\ + \ from Ottoman Sultan Selim I to Safavid Shah Ismail I, 1514\nThe letter from\ + \ Selim I is most clearly an example of which of the following?\n(A) The maintenance\ + \ of military supremacy at all costs (B) Expanding tensions between religious\ + \ sects (C) Factors that brought about the collapse of the Ottoman Empire (D)\ + \ Peacemaking efforts among the Islamic empires" + target: Let's think step by step. We refer to Wikipedia articles on world history + for help. The passage is an example of expanding tensions between Selim and + Ismail. In the passage the Selim references the fatwa and the consensus of the + Sunni community to against whosoever introduces evil. The answer is (B). + - question: 'This question refers to the following information. + + "The real grievance of the worker is the insecurity of his existence; he is + not sure that he will always have work, he is not sure that he will always be + healthy, and he foresees that he will one day be old and unfit to work. If he + falls into poverty, even if only through a prolonged illness, he is then completely + helpless, exam_ins to his own devices, and society does not currently recognize + any real obligation towards him beyond the usual help for the poor, even if + he has been working all the time ever so faithfully and diligently. The usual + help for the poor, however, leaves a lot to be desired, especially in large + cities, where it is very much worse than in the country." + + Otto von Bismarck, 1884 + + Otto von Bismarck likely made this speech in reaction to which of the following + issues? + + (A) Social acceptance of child labor (B) Declining life expectancy in Germany + (C) Criticisms of German trade tariffs (D) Negative effects attributed to industrial + capitalism' + target: Let's think step by step. We refer to Wikipedia articles on world history + for help. The passage talks about the grievance of the work under the industrial + capitalism. The answer is (D). + - question: "This question refers to the following information.\nHe contains all works\ + \ and desires and all perfumes and all tastes. He enfolds the whole universe\ + \ and in silence is loving to all. This is the Spirit that is in my heart, this\ + \ is Brahman. To him I shall come when I go beyond this life, and to him will\ + \ come he who has faith and doubts not.\n\u2014The Upanishads, India, c. 1000\ + \ BCE\nTo which religion does the speaker most likely belong?\n(A) Hinduism\ + \ (B) Buddhism (C) Shintoism (D) Zoroastrianism" + target: 'Let''s think step by step. We refer to Wikipedia articles on world history + for help. Brahman refers to the ultimate reality of all things in the Hindu + religion. In contrast, Buddhism does not have a concept of supreme God. The + answer is (A).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_high_school_world_history diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3eec010845fa68ad974bdb7cd922a0028365d96e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_human_aging.yaml @@ -0,0 +1,42 @@ +dataset_name: human_aging +description: The following are multiple choice questions (with answers) about human + aging. +fewshot_config: + sampler: first_n + samples: + - question: 'All other things being equal, which of the following persons is more likely + to show osteoporosis? + + (A) An older Hispanic American woman (B) An older African American woman (C) + An older Asian American woman (D) An older Native American woman' + target: Let's think step by step. We refer to Wikipedia articles on human aging + for help. Although osteoporosis can occur at any age, the risk is higher for + older people. It is most common in Asian and non-Hispanic white women. The answer + is (C). + - question: 'The finding that adults tend to remember events from their adolescence + better than from other periods in their lives is referred to as the + + (A) Adolescence advantage (B) Reminiscence bump (C) Memorial memorial (D) Quadratic + retrieval spike' + target: Let's think step by step. We refer to Wikipedia articles on human aging + for help. Reminiscence bump is a phenomenon that older adults tend to recollect + events during their young ages. People usually have a period of childhood amnesia + from birth to around age 5, and a reminiscence bump between 10 and 30. The answer + is (B). + - question: 'Which element in tobacco smoke is responsible for cancers? + + (A) Nicotine (B) Tar (C) Carbon monoxide (D) Smoke particles' + target: Let's think step by step. We refer to Wikipedia articles on human aging + for help. The benzene, acrylamide and acrylonitrile in tar interact with the + lungs and cause DNA mutations in cells of the lungs, and lead to cancer. The + answer is (B). + - question: 'When older adults move to a new state after retirement, which of the following + is the more likely destination? + + (A) Texas (B) California (C) Hawaii (D) Vermont' + target: 'Let''s think step by step. We refer to Wikipedia articles on human aging + for help. Texas does not have state tax, and has low cost of living compared + with the other three options. The answer is (A).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_human_aging diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dab78f0c5ec1042d23240bb71f59b212885585aa --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_human_sexuality.yaml @@ -0,0 +1,50 @@ +dataset_name: human_sexuality +description: The following are multiple choice questions (with answers) about human + sexuality. +fewshot_config: + sampler: first_n + samples: + - question: 'The most common disorder among men who seek sexual therapy is: + + (A) premature ejaculation (B) inhibited ejaculation (C) erectile disorder (D) + ejaculatory disorder' + target: Let's think step by step. We refer to Wikipedia articles on human sexuality + for help. The most common type of arousal disorder for men is erectile dysfunction, + meaning that a person is not able to get any physical satisfaction from sexual + activity although he may be interested in it. The answer is (C). + - question: 'A woman who knows she has active herpes and untreated syphilis but continues + to have sex without informing her partners of her condition has, in psychoanalytic + terms: + + (A) a strong ego (B) a weak superego (C) a weak id (D) a strong superego' + target: Let's think step by step. We refer to Wikipedia articles on human sexuality + for help. A person with weak superego tends to be delinquent, criminal or have + antisocial personality. The action of the woman who knows she has active venereal + disease but still have sex with her partners indicate she may has antisocial + personality. The answer is (B). + - question: 'The nature of homosexual activities that occur during preadolescence include + all but which of the following? + + (A) sexual intercourse (B) circle jerks (C) exhibitionism (D) touching each + other''s genitals' + target: Let's think step by step. We refer to Wikipedia articles on human sexuality + for help. Sexual intercourse is prohibited during preadolescence. The answer + is (A). + - question: 'Women''s ability to have multiple orgasms is primarily due to: + + (A) the fact that they do not have a refractory period. (B) the response of + the inner layers of the vagina. (C) having alternating orgasms in different + locations. (D) the G-Spot.' + target: Let's think step by step. We refer to Wikipedia articles on human sexuality + for help. The refractory period is the time when a person is not able to be + erect or is not interested in sex. The answer is (A). + - question: 'Morning sickness is typically a problem: + + (A) during the first trimester (B) during the second trimester (C) during the + third trimester (D) all through the pregnancy' + target: 'Let''s think step by step. We refer to Wikipedia articles on human sexuality + for help. Morning sickness usually begins by nine weeks after conception, corresponding + to the first trimester. The answer is (A).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_human_sexuality diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99341f395352d7b6a9a8d1a71005ca821ac9b723 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_international_law.yaml @@ -0,0 +1,70 @@ +dataset_name: international_law +description: The following are multiple choice questions (with answers) about international + law. +fewshot_config: + sampler: first_n + samples: + - question: 'How the consent to be bound of a State may be expressed? + + (A) The consent of a State to be bound is expressed only by ratification (B) + The consent of a state to be bound by a treaty may be expressed by signature, + ratification, acceptance, approval or accession (C) The consent of a State to + be bound is expressed by signature (D) The consent of a State to be bound is + expressed by whatever means they choose' + target: Let's think step by step. We refer to Wikipedia articles on international + law for help. Article 11 of Vienna Convention on the Law of Treaties signed + in 1969 states that "the consent of a State to be bound by a treaty may be expressed + by signature, exchange of instruments constituting a treaty, ratification, acceptance, + approval or accession, or by any other means if so agreed." (B) is the most + precise and accurate answer. The answer is (B). + - question: 'What is the judge ad hoc? + + (A) If a party to a contentious case before the ICJ does not have a national + sitting as judge, it is entitled to nominate someone as a judge solely for that + case, with the title of judge ad hoc (B) Judge ad hoc is the member of the bench + of the ICJ with a casting vote (C) Judge ad hoc is a surrogate judge, in case + a judge is disqualified or passes away (D) Judge ad hoc is the judge that each + party will always nominate in every contentious case' + target: Let's think step by step. We refer to Wikipedia articles on international + law for help. As "ad hoc" implies, a judge ad hoc is appointed only for a specific + case or period, when a party to a contentious case before the International + Court of Justice does not have a regular national sitting as judge. The answer + is (A). + - question: 'When ''consent'' can serve as a circumstance precluding the wrongfulness + of a State conduct? + + (A) Consent can serve as a circumstance precluding the wrongfulness whenever + it is given (B) Consent can never serve as a circumstance precluding wrongfulness + (C) Consent can serve as a circumstance precluding wrongfulness, provided the + consent is valid and to the extent that the conduct remains within the limits + of the consent given (D) Consent can always serve as a circumstance precluding + wrongfulness, no matter which organ of the State gives it' + target: Let's think step by step. We refer to Wikipedia articles on international + law for help. Valid consent can serve as a circumstance precluding the wrongfulness + of a State conduct if the conduct remains within the limits of that consent, + according to Chapter V of the Responsibility of States for Internationally Wrongful + Acts, 2001, United Nations. The answer is (C). + - question: 'Would a reservation to the definition of torture in the ICCPR be acceptable + in contemporary practice? + + (A) This is an acceptable reservation if the reserving country''s legislation + employs a different definition (B) This is an unacceptable reservation because + it contravenes the object and purpose of the ICCPR (C) This is an unacceptable + reservation because the definition of torture in the ICCPR is consistent with + customary international law (D) This is an acceptable reservation because under + general international law States have the right to enter reservations to treaties' + target: Let's think step by step. We refer to Wikipedia articles on international + law for help. For it contravenes the object and purpose of the ICCPR, this is + an unacceptable reservation in contemporary practice. The answer is (B). + - question: 'What types of force does Article 2(4) of the UN Charter prohibit? + + (A) Article 2(4) encompasses only armed force (B) Article 2(4) encompasses all + types of force, including sanctions (C) Article 2(4) encompasses all interference + in the domestic affairs of States (D) Article 2(4) encompasses force directed + only against a State''s territorial integrity' + target: 'Let''s think step by step. We refer to Wikipedia articles on international + law for help. Article 2(4) of the UN Charter prohibits states from using armed + forces in their international relations. The answer is (A).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_international_law diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c714f7e595a48d28b750479a46183a02fc24dc0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_jurisprudence.yaml @@ -0,0 +1,59 @@ +dataset_name: jurisprudence +description: The following are multiple choice questions (with answers) about jurisprudence. +fewshot_config: + sampler: first_n + samples: + - question: 'Iverson Jewelers wrote a letter to Miller, ''We have received an exceptionally + fine self winding Rolox watch which we will sell to you at a very favorable + price.'' + + (A) The letter is an offer to sell (B) A valid offer cannot be made by letter. + (C) The letter contains a valid offer which will terminate within a reasonable + time. (D) The letter lacks one of the essential elements of an offer.' + target: Let's think step by step. We refer to Wikipedia articles on jurisprudence + for help. An offer shows the intent to enter into a mutually-beneficial contract + with specific terms. An offer can be made by a letter. While this letter indicates + the willingness to sell, the lack of specific terms, such as transaction price + and offer expiration date, makes it an incomplete offer. The answer is (D). + - question: 'Functions of the law include all but which of the following? + + (A) maximizing individual freedom (B) providing a basis for compromise (C) keeping + the peace (D) promoting the principles of the free enterprise system' + target: Let's think step by step. We refer to Wikipedia articles on jurisprudence + for help. Laws are fundamentally about helping resolve disputes between individuals, + and therefore essential for maximizing individual freedom, providing a basis + for compromise, and keeping the peace. The answer is (D). + - question: 'The ________ School of jurisprudence postulates that the law is based + on what is "correct." + + (A) Natural Law (B) Analytical (C) Historical (D) Sociological' + target: Let's think step by step. We refer to Wikipedia articles on jurisprudence + for help. Natural Law School of jurisprudence focuses on the laws of nature, + and states that the law should be based on ethics, morals, and what is "correct". + Analytical deals with the law as it already exists, Historical postulates that + the law was found and not made, and Sociological studies how the law and society + impact each other. The answer is (A). + - question: 'Which word best summarizes Weber''s explanation of the development of + formally rational law? + + (A) Authority. (B) Charisma. (C) Co-operation. (D) Capitalism.' + target: Let's think step by step. We refer to Wikipedia articles on jurisprudence + for help. Weber explained the development of formal rationality in laws as how + the modern society moved from tradition to rationality, where people decide + actions based less on how they were culturally done and more on expected utilities. + How rational individuals optimize efficiency of accomplishing tasks for higher + rewards is a core principle of Capitalism. The answer is (D). + - question: 'Which position does Rawls claim is the least likely to be adopted by the + POP (people in the original position)? + + (A) The POP would choose equality above liberty. (B) The POP would opt for the + ''maximin'' strategy. (C) The POP would opt for the ''difference principle''. + (D) The POP would reject the ''system of natural liberty.''' + target: 'Let''s think step by step. We refer to Wikipedia articles on jurisprudence + for help. The POP would opt for the ''maximin'' strategy, opt for the ''difference + principle'', and reject the ''system of natural liberty'', but the POP would + not choose equality above liberty, since the POP assume both equal and free + citizens. The answer is (A).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_jurisprudence diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd72cab8cf8ea0553a633810e59c0ba4059c2176 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_logical_fallacies.yaml @@ -0,0 +1,61 @@ +dataset_name: logical_fallacies +description: The following are multiple choice questions (with answers) about logical + fallacies. +fewshot_config: + sampler: first_n + samples: + - question: 'When an arguer causes confusion during refutation because of real or feigned + lack of an ability to engage in refutation, that arguer may have committed the + fallacy of + + (A) poor sportsmanship (B) appeal to compassion (C) argument against the person + (D) ignorance of refutation' + target: Let's think step by step. We refer to Wikipedia articles on logical fallacies + for help. Ignorance of refutation, one of Aristotle's original list of logical + fallacies in his Organon, is when someone causes confusion in an argument through + real or feigned inability to engage in refutation, in order to win the argument. + The answer is (D). + - question: 'The complex question fallacy consists of + + (A) arguing something is inferior just because it doesn''t do something it was + never intended to do. (B) including more than one claim in the proposition and + treating proof for one claim as proof for all the claims. (C) drawing a conclusion + before examining the evidence, and only considering evidence that supports that + conclusion. (D) asking a question that includes either an unproven assumption + or more than one question, thus making a straightforward yes or no answer meaningless.' + target: Let's think step by step. We refer to Wikipedia articles on logical fallacies + for help. The complex question fallacy is when someone makes a single yes or + no answer to a question meaningless, by including either an unproven assumption + or many questions. The latter is also known as the many questions fallacy. The + answer is (D). + - question: 'Arguing that what is true of the parts must be true of the whole is the + fallacy of... + + (A) Division (B) Composition (C) Appeal to the person (D) Appeal to ignorance' + target: Let's think step by step. We refer to Wikipedia articles on logical fallacies + for help. Fallacy of composition occurs when someone argues what is true of + the parts must be true of the whole. The answer is (B). + - question: 'Which of the following is true of a valid categorical syllogism? + + (A) The minor premise must deny the antecedent (B) The major premise must affirm + the consequent (C) The middle term must be used in at least one premise in a + universal or unqualified sense (D) All of the above' + target: 'Let''s think step by step. We refer to Wikipedia articles on logical + fallacies for help. A valid categorical syllogism must satisfy several conditions: + (1) the syllogism must have exactly three terms (2) every term of the syllogism + must be used twice exactly, (3) a term may be used only once in any premise, + and (4) the middle term must be used in at least one premise in a universal + or unqualified sense, etc. Only (C) is true. The answer is (C).' + - question: 'If someone attacks the character of an opposing arguer, instead of responding + to that opponent''s arguments, the first person has probably committed which + of the following fallacies? + + (A) tu quoque (B) horse laugh (C) argument against the person (D) ignoratio + elenchi' + target: 'Let''s think step by step. We refer to Wikipedia articles on logical + fallacies for help. The argument against the person fallacy occurs when someone + irrelevantly attacks the character of an opposing arguer, instead of addressing + that opponent''s arguments. The answer is (C).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_logical_fallacies diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33622ac4e7291eb380b6e58382e4fe84052a5bdc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_machine_learning.yaml @@ -0,0 +1,74 @@ +dataset_name: machine_learning +description: The following are multiple choice questions (with answers) about machine + learning. +fewshot_config: + sampler: first_n + samples: + - question: 'Which image data augmentation is most common for natural images? + + (A) random crop and horizontal flip (B) random crop and vertical flip (C) posterization + (D) dithering' + target: Let's think step by step. Data augmentation is used to increase the diversity + of images in the training dataset. It is important that natural images are kept + natural after being augmented. Vertical flips of images are not natural, so + (B) is false. Posterization makes the image look like a poster and and dithering + increases color depth. None of these two preserve the natural property. The + only natural data augmentation technique is (A). The answer is (A). + - question: "Traditionally, when we have a real-valued question attribute during decision-tree\ + \ learning we consider a binary split according to whether the attribute is\ + \ above or below some threshold. Pat suggests that instead we should just have\ + \ a multiway split with one branch for each of the distinct values of the attribute.\ + \ From the list below choose the single biggest problem with Pat\u2019s suggestion:\n\ + (A) It is too computationally expensive. (B) It would probably result in a decision\ + \ tree that scores badly on the training set and a testset. (C) It would probably\ + \ result in a decision tree that scores well on the training set but badly on\ + \ a testset. (D) It would probably result in a decision tree that scores well\ + \ on a testset but badly on a training set." + target: "Let's think step by step. Because the question is real valued, it is unlikely\ + \ that the same values appear both at training and test time. This means that\ + \ while such a decision tree could yield good performance on the training data,\ + \ when evaluated on the test data it will perform badly because the decision\ + \ tree won\u2019t know what to do with numbers that did not appear in the training\ + \ data. The answer is (C)." + - question: "You are reviewing papers for the World\u2019s Fanciest Machine Learning\ + \ Conference, and you see submissions with the following claims. Which ones\ + \ would you consider accepting?\n(A) My method achieves a training error lower\ + \ than all previous methods! (B) My method achieves a test error lower than\ + \ all previous methods! (Footnote: When regularisation parameter \u03BB is chosen\ + \ so as to minimise test error.) (C) My method achieves a test error lower than\ + \ all previous methods! (Footnote: When regularisation parameter \u03BB is chosen\ + \ so as to minimise cross-validaton error.) (D) My method achieves a cross-validation\ + \ error lower than all previous methods! (Footnote: When regularisation parameter\ + \ \u03BB is chosen so as to minimise cross-validaton error.)" + target: "Let's think step by step. In machine learning, we train with some data\ + \ and fixed hyperparameters and the training error can be arbitrarily low, so\ + \ (A) can\u2019t be right. Then, one compares different hyperparameters by selecting\ + \ the model with the lowest cross-validation error, this means that (B) and\ + \ (D) are not the right procedure. The only relevant number after these is the\ + \ test error and thus (C) is the right answer. The answer is (C)." + - question: 'A 6-sided die is rolled 15 times and the results are: side 1 comes up + 0 times; side 2: 1 time; side 3: 2 times; side 4: 3 times; side 5: 4 times; + side 6: 5 times. Based on these results, what is the probability of side 3 coming + up when using Add-1 Smoothing? + + (A) 2.0/15 (B) 1.0/7 (C) 3.0/16 (D) 1.0/5' + target: 'Let''s think step by step. Add-1 smoothing adds the value of one to the + different counts and then normalizes the probabilities accordingly. The counts + after adding one will be: side 1 comes up 1 time; side 2: 2 times; side 3: 3 + times; side 4: 4 times; side 5: 5 times; side 6: 6 times. The number of sum + one die rolls will be 21, so the probability of drawing a three is 3/21 = 1/7. + The answer is (B).' + - question: 'To achieve an 0/1 loss estimate that is less than 1 percent of the true + 0/1 loss (with probability 95%), according to Hoeffding''s inequality the IID + test set must have how many examples? + + (A) around 10 examples (B) around 100 examples (C) between 100 and 500 examples + (D) more than 1000 examples' + target: "Let's think step by step. By the Hoeffding\u2019s inequality, we expect\ + \ that with 95% probability the in-sample and out-of-sample errors differ by\ + \ epsilon when we have N samples if 2 exp(-2 epsilon^2 N)<0.05, this implies\ + \ that N > -1/(2*epsilon**2) log ( 0.05/2 )= log (40)*5000. Since log(40)>1,\ + \ we have that one needs more than 1000 examples. The answer is (D).\n\n" +tag: mmlu_flan_cot_fewshot_stem +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_machine_learning diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87d9ba8c9aa31733a5849695213d97deff9c2ded --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_management.yaml @@ -0,0 +1,46 @@ +dataset_name: management +description: The following are multiple choice questions (with answers) about management. +fewshot_config: + sampler: first_n + samples: + - question: 'How can organisational structures that are characterised by democratic + and inclusive styles of management be described? + + (A) Hierarchical (B) Bureaucratic (C) Flat (D) Functional' + target: Let's think step by step. We refer to Wikipedia articles on management + for help. Flat organizational structures are characterized by democratic and + inclusive styles of management, and have few (if any) levels of management between + the workers and managers. The answer is (C). + - question: 'Hygiene factors are associated with which writer? + + (A) Frederick Hertzberg (B) D.C. McClelland (C) Abraham Maslow (D) Douglas McGregor' + target: Let's think step by step. We refer to Wikipedia articles on management + for help. Hygiene factors include compensation, company policies, supervision, + interpersonal relations, and work environments. Hertzberg lists them as factors + that cannot motivate employees but can minimize job dissatisfaction. The answer + is (A). + - question: 'What characteristic is not a key feature of the ''open systems'' model + of management? + + (A) Morale (B) Innovation (C) Growth resource (D) Adaptation' + target: Let's think step by step. We refer to Wikipedia articles on management + for help. The key characteristics of an open system in management include innovation, + growth resource, and adaption, but do not include morale. The answer is (A). + - question: 'Which element of the cultural web forms regalia? + + (A) Symbols (B) Rituals and routines (C) Power structures (D) Control systems' + target: Let's think step by step. We refer to Wikipedia articles on management + for help. The cultural web is a tool for mapping an organization's culture, + where symbols form the regalia that visually expresses the values that the organization + holds as important. The answer is (A). + - question: 'What are the two main dimensions of the Ohio Studies into leadership? + + (A) Starting position and end position (B) Initial environment and changed environment + (C) Organisational structure and conditioning (D) Initiating structure and considerations' + target: 'Let''s think step by step. We refer to Wikipedia articles on management + for help. The Ohio State Leadership Studies conducted in the 1940s identified + initiating structure and consideration as the two main dimensions of leader + behavior. The answer is (D).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_management diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..182eb52ec509c34c225a176774be653d747e120e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_marketing.yaml @@ -0,0 +1,56 @@ +dataset_name: marketing +description: The following are multiple choice questions (with answers) about marketing. +fewshot_config: + sampler: first_n + samples: + - question: 'Although the content and quality can be as controlled as direct mail, + response rates of this medium are lower because of the lack of a personal address + mechanism. This media format is known as: + + (A) Care lines. (B) Direct mail. (C) Inserts. (D) Door to door.' + target: Let's think step by step. We refer to Wikipedia articles on marketing + for help. Door to door marketing delivers non-addressed items within all buildings + within a geographic area. While it can control the content and quality as well + as direct mail marketing, its response rate is lower because of the lack of + a personal address mechanism. The answer is (D). + - question: 'In an organization, the group of people tasked with buying decisions is + referred to as the _______________. + + (A) Outsourcing unit. (B) Procurement centre. (C) Chief executive unit. (D) + Decision-making unit.' + target: Let's think step by step. We refer to Wikipedia articles on marketing + for help. In an organization, the group of the people tasked with buying decision + is referred to as the decision-making unit. The answer is (D). + - question: 'The single group within society that is most vulnerable to reference group + influence is: + + (A) The older consumer who feels somewhat left out of things. (B) The married + women, many of whom feel a need for stability in their lives. (C) New immigrants + who really want to assimilate into their new culture. (D) Children, who base + most of their buying decisions on outside influences.' + target: Let's think step by step. We refer to Wikipedia articles on marketing + for help. Children, who mostly based their buying decisions on outside influences, + are the single group within society that is more vulnerable to reference group + influence. The answer is (D). + - question: 'Which of the following is an assumption in Maslow''s hierarchy of needs? + + (A) Needs are dependent on culture and also on social class. (B) Lower-level + needs must be at least partially satisfied before higher needs can affect behaviour. + (C) Needs are not prioritized or arranged in any particular order. (D) Satisfied + needs are motivators, and new needs emerge when current needs remain unmet.' + target: Let's think step by step. We refer to Wikipedia articles on marketing + for help. Maslow's hierarchy of needs, from the bottom upwards, are physiological + (food and clothing), safety, love and belonging needs, esteem, and self-actualization. + Lower-level needs must be at least partially satisfied before higher ones can + affect behavior. The answer is (B). + - question: '_____________ is a natural outcome when combining demographic and geographic + variables. + + (A) Geodemographics (B) Product differentiation. (C) ANSOFF matrix. (D) Brand + management.' + target: 'Let''s think step by step. We refer to Wikipedia articles on marketing + for help. Geodemographics is a natural outcome when combining demographic and + geographic variables. The answer is (A).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_marketing diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..757010bfed0f08f882995eed787d4f68e0f8121c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_medical_genetics.yaml @@ -0,0 +1,51 @@ +dataset_name: medical_genetics +description: The following are multiple choice questions (with answers) about medical + genetics. +fewshot_config: + sampler: first_n + samples: + - question: 'The stage of meiosis in which chromosomes pair and cross over is: + + (A) prophase I (B) metaphase I (C) prophase II (D) metaphase II' + target: Let's think step by step. We refer to Wikipedia articles on medical genetics + for help. Prophase I is the stage of meiosis where homologous chromosomes pair + with each other and exchange genetic material. The answer is (A). + - question: 'DNA ligase is + + (A) an enzyme that joins fragments in normal DNA replication (B) an enzyme of + bacterial origin which cuts DNA at defined base sequences (C) an enzyme that + facilitates transcription of specific genes (D) an enzyme which limits the level + to which a particular nutrient reaches' + target: Let's think step by step. We refer to Wikipedia articles on medical genetics + for help. DNA ligase is a type of enzyme (EC 6.5.1.1) responsible for joining + DNA strands together by catalyzing a phosphodiester bond. The answer is (A). + - question: 'Which of the following conditions does not show multifactorial inheritance? + + (A) Pyloric stenosis (B) Schizophrenia (C) Spina bifida (neural tube defects) + (D) Marfan syndrome' + target: Let's think step by step. We refer to Wikipedia articles on medical genetics + for help. Multifactorial inheritance is when more than a single factor is responsible + for causing a given trait or health problem. Genes cannot be the only factor. + Marfan syndrome, on the other hand, requires only one abnormal copy of the of + the Marfan gene, from one parent, to inherit the trait. The answer is (D). + - question: 'A gene showing codominance + + (A) has both alleles independently expressed in the heterozygote (B) has one + allele dominant to the other (C) has alleles tightly linked on the same chromosome + (D) has alleles expressed at the same time in development' + target: Let's think step by step. We refer to Wikipedia articles on medical genetics + for help. Codominance, as it relates to genetics, refers to a type of genetic + inheritance where the phenotype of both the parents is easily observed in the + offspring. A heterozygote is an individual having two different alleles of a + gene. The answer is (A). + - question: 'Large triplet repeat expansions can be detected by: + + (A) polymerase chain reaction. (B) single strand conformational polymorphism + analysis. (C) Southern blotting. (D) Western blotting.' + target: 'Let''s think step by step. We refer to Wikipedia articles on medical + genetics for help. A Southern blot is a method in molecular biology for detecting + specific DNA sequences in a sample. Large triplet repeat expansions are usually + detected with this method. The answer is (C).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_medical_genetics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2fe892eb06f522df6e99606013d26e0df1517cf3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_miscellaneous.yaml @@ -0,0 +1,43 @@ +dataset_name: miscellaneous +description: The following are multiple choice questions (with answers) about miscellaneous. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of these songs was a Top 10 hit for the rock band The Police? + + (A) ''Radio Ga-Ga'' (B) ''Ob-la-di Ob-la-da'' (C) ''De Do Do Do De Da Da Da'' + (D) ''In-a-Gadda-Da-Vida''' + target: Let's think step by step. We refer to Wikipedia for help. Radio Ga-Ga + is by Queen. Ob-la-di Ob-la-da is by The Beatles. And In-a-Gadda-Da-Vida is + by Iron Butterfly. Leaving 'De Do Do Do De Da Da Da' as the only song by The + Police, and also a Top 10 hit. The answer is (C). + - question: 'What place is named in the title of the 1979 live album by rock legends + Cheap Trick? + + (A) Budapest (B) Budokan (C) Bhutan (D) Britain' + target: Let's think step by step. We refer to Wikipedia for help. Nippon Budokan + is an indoor arena in Tokyo, Japan renowned for hosting rock music concerts + including Cheap Trick in 1978. 'Cheap Trick at Budokan' became the name of their + album. The answer is (B). + - question: 'What is produced during photosynthesis? + + (A) hydrogen (B) nylon (C) oxygen (D) light' + target: Let's think step by step. We refer to Wikipedia for help. Photosynthesis + is the process in which green plants use the green pigment chlorophyll to synthesize + foods with water and carbon dioxide. Oxygen is the byproduct of this process. + The answer is (C). + - question: 'Who is the shortest man to ever win an NBA slam dunk competition? + + (A) Anthony ''Spud'' Webb (B) Michael ''Air'' Jordan (C) Tyrone ''Muggsy'' Bogues + (D) Julius ''Dr J'' Erving' + target: Let's think step by step. We refer to Wikipedia for help. In 1986, Spud + Webb, standing only 5'7" became the shortest NBA player in history to win an + official slam dunk contest. The answer is (A). + - question: 'How many axles does a standard automobile have? + + (A) one (B) two (C) four (D) eight' + target: 'Let''s think step by step. We refer to Wikipedia for help. Most cars + have two axles to rotate the wheels.. The answer is (B).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_miscellaneous diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45a92e075582d6c0e2eb11c0310f87fc5debb4bb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_moral_disputes.yaml @@ -0,0 +1,64 @@ +dataset_name: moral_disputes +description: The following are multiple choice questions (with answers) about moral + disputes. +fewshot_config: + sampler: first_n + samples: + - question: 'Baron admits that the versions of the ticking bomb hypothetical she discusses + are "stunningly stupid," but she claims this is actually evidence of + + (A) the stupidity of most traditional philosophical examples. (B) a general + lack of intelligence among people with advanced degrees. (C) the wrongness of + torture. (D) the readiness on the part of many intelligent people to see torture + as the best solution to deal with terrorism.' + target: Let's think step by step. We refer to Wikipedia articles on moral disputes + for help. The ticking bomb hypothetical poses a problem where many people will + die to an exploding bomb, if the hypothetical terrorist does not disclose how + to defuse it. Baron sees this hypothetical as silly, but its prevalence does + suggest intelligent people, particularly utilitarians, see torture as justifiable + to save the lives in this scenario. The answer is (D). + - question: 'A fertilized ovum is also known as + + (A) a zygote. (B) an embryo. (C) a viability. (D) a blastocyst.' + target: Let's think step by step. We refer to Wikipedia articles on moral disputes + for help. Once a single sperm penetrates the layers of an egg to form a new + cell, that cell is called a zygote. The answer is (A). + - question: 'Pence compares six different cases of reproduction, from natural twinning + to SCNT. What conclusion does he draw from this comparison? + + (A) SCNT is not a different kind of reproduction because there are no morally + relevant differences between it and other permissible means of reproduction. + (B) Because there is a low risk of harm for natural twinning, there will be + a low risk of harm for SCNT. (C) both A and B (D) neither A nor B' + target: Let's think step by step. We refer to Wikipedia articles on moral disputes + for help. SCNT, known as 'somatic cell nuclear transfer', is a technique in + which the nucleus of a somatic (body) cell is transferred to the cytoplasm of + an enucleated egg. This is used in farm animal cloning. Pence may compare this + to other processes of reproduction, and see little moral separation. The answer + is (A). + - question: 'Which of the following is an intrinsic good, according to Singer? + + (A) being in the upper class (B) the ability to learn and solve problems (C) + being more beautiful than most of the population (D) being shorter than average' + target: Let's think step by step. We refer to Wikipedia articles on moral disputes + for help. Bioethicist Peter Singer sees intrinsic value as innate values conferred + by oneself, for oneself. Innanimate objects can be beautiful, short, or have + some valuable criteria, but capabilities are intrinsically good. The answer + is (B). + - question: 'According to Metz, what is wrong with consequentialist arguments against + capital punishment based on African values? + + (A) It is unclear as of yet whether or not capital punishment deters harm to + the community. (B) It is unclear as of yet whether or not capital punishment + deters harm to any individuals. (C) Consequentialism is not supported by African + values. (D) Even though consequentialism is supported by African values, no + consequentialist arguments framed in terms of African values have been offered.' + target: 'Let''s think step by step. We refer to Wikipedia articles on moral disputes + for help. Thaddeus Metz is a humanities research professor focusing on ethical + philosophy in South Africa. Metz has written the death penalty is unjustified + as it treats individuals as incapable of communal relations. It is unclear that + capital punishment is to the benefit of, or a deterrent of harm to the community. + The answer is (A).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_moral_disputes diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b4a0231930d62fe73a1cb84b58420baf0182365c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_moral_scenarios.yaml @@ -0,0 +1,62 @@ +dataset_name: moral_scenarios +description: The following are multiple choice questions (with answers) about moral + scenarios. +fewshot_config: + sampler: first_n + samples: + - question: 'For which of these two scenarios does the main character (who uses I/me/my) + do something clearly morally wrong, according to ordinary moral standards in + the US as of 2020? Scenario 1 | I took his wallet after he harassed me at the + bar. Scenario 2 | I made love to my newborn baby boy. + + (A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not + wrong' + target: Let's think step by step. We refer to Wikipedia articles on moral scenarios + for help. Stealing from someone is not considered an ethical response to harrassment. + And having sex with a minor is egregiously unethical. The answer is (A). + - question: 'For which of these two scenarios does the main character (who uses I/me/my) + do something clearly morally wrong, according to ordinary moral standards in + the US as of 2020? Scenario 1 | I asked for a cigarette from the guy outside + the store. Scenario 2 | When the security guard left his post, I pocketed a + rare coin. + + (A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not + wrong' + target: Let's think step by step. We refer to Wikipedia articles on moral scenarios + for help. Asking for something is not morally wrong. Stealing is morally wrong. + The answer is (C). + - question: 'For which of these two scenarios does the main character (who uses I/me/my) + do something clearly morally wrong, according to ordinary moral standards in + the US as of 2020? Scenario 1 | During the house''s demolition, I pushed him + into a wall. Scenario 2 | When I choose a babysitter for my twin nieces, I don''t + care about her racial background. + + (A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not + wrong' + target: Let's think step by step. We refer to Wikipedia articles on moral scenarios + for help. Physical violence is morally wrong. Being open to hiring people from + different racial backgrounds is an expected moral stance. The answer is (B). + - question: 'For which of these two scenarios does the main character (who uses I/me/my) + do something clearly morally wrong, according to ordinary moral standards in + the US as of 2020? Scenario 1 | I bought an old pick-up truck for work. Scenario + 2 | I looked up at the gorgeous waterfall and threw a rock at it. + + (A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not + wrong' + target: Let's think step by step. We refer to Wikipedia articles on moral scenarios + for help. Neither of these are wrong as no one is wronged by buing a pick-up + truck or throwing a rock into a stream. The answer is (D). + - question: 'For which of these two scenarios does the main character (who uses I/me/my) + do something clearly morally wrong, according to ordinary moral standards in + the US as of 2020? Scenario 1 | I really loved my newborn baby boy. Scenario + 2 | The new girl is embarrassed to live in a trailer park, so I invite the class + to her house. + + (A) Wrong, Wrong (B) Wrong, Not wrong (C) Not wrong, Wrong (D) Not wrong, Not + wrong' + target: 'Let''s think step by step. We refer to Wikipedia articles on moral scenarios + for help. Loving someone is not wrong. However, exposing something that someone + is embarrassed about could be considered quite mean. The answer is (C).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_moral_scenarios diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66498dc564350f893e7bd45078528b5750bce0ca --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_nutrition.yaml @@ -0,0 +1,63 @@ +dataset_name: nutrition +description: The following are multiple choice questions (with answers) about nutrition. +fewshot_config: + sampler: first_n + samples: + - question: 'What is the first-line drug for patients with type 2 diabetes and obesity, + as of 2020? + + (A) Acarbose (B) Metformin (C) Sulphonylureas (D) Insulin' + target: Let's think step by step. We refer to Wikipedia articles on nutrition + for help. Metformin (Fortamet, Glumetza, or others) is usually the first medication + prescribed for type 2 diabetes, as well as obesity. It works by lowering glucose + production in the liver and improving the body's sensitivity to insulin. The + answer is (B). + - question: 'Which of the following statements is correct (according to knowledge in + 2020)? + + (A) Consumers with phenylketonuria must avoid the consumption of the sweetener + aspartame (B) Consumers with phenylketonuria must avoid the consumption of the + sweetener saccharin (C) Consumers with phenylketonuria must avoid the consumption + of the sweetener sucralose (D) Consumers with phenylketonuria must avoid the + consumption of the sweetener acesulfame K' + target: Let's think step by step. We refer to Wikipedia articles on nutrition + for help. People with phenylketonuria (PKU) cannot break down the amino acid + phenylalanine. As it builds up in the blood and brain it can lead to brain damage. + People with PKU should avoid foods that are converted to phenylalanine in the + body, such as aspartame. The answer is (A). + - question: 'Which of the following statements about iodine is correct, as of 2020? + + (A) 50% of adults consume iodine at levels below the RNI (B) Dairy products + are a poor source of iodine (C) The iodine content of organic milk is generally + lower that the level in non-organic milk (D) UK dietary reference values recommend + an increase in iodine intake in pregnancy' + target: Let's think step by step. We refer to Wikipedia articles on nutrition + for help. Organic milk usually has less iodine content than non-organic milk. + The answer is (C). + - question: 'Which of the following is the most plausible explanation for the protective + effect of dietary fibre against cancer of the colon, as of 2020? + + (A) Propionic acid, formed during colonic fibre fermentation inhibits liver + fatty acid synthesis (B) Butyric acid, formed during colonic fibre fermentation + stimulates "silencing" of the SLC5A8 tumour suppressor gene (C) None of these + options are correct (D) Butyric acid, formed during colonic fibre fermentation + stimulates anti-oxidant defences in the colon' + target: Let's think step by step. We refer to Wikipedia articles on nutrition + for help. Dietary fibre is inversely proportional to the risk of colorectal + cancer. This is presumed because butyric acid (BA) stimulates antioxidants which + help protect the colon from cancerous tumors. The answer is (D). + - question: 'In a cohort study, the risk ratio of developing diabetes was 0.86 when + comparing consumers of tea (the exposed) to those who did not drink tea (the + unexposed). Which one statement is correct (according to knowledge in 2020)? + + (A) The tea drinkers have lower risk of developing diabetes. (B) The tea drinkers + have higher risk of developing diabetes. (C) Based on the information given + we cannot tell if the observed difference in disease risk is the result of chance. + (D) The risk ratio is close to the value one, so there is no difference in disease + risk between the two groups.' + target: 'Let''s think step by step. We refer to Wikipedia articles on nutrition + for help. The risk ratio is not sufficiently reduced that it could not be explained + by random chance given the studies sample size. The answer is (C).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_nutrition diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5116be01d6a463f23f9b7422de1b0b48bbdb55e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_philosophy.yaml @@ -0,0 +1,44 @@ +dataset_name: philosophy +description: The following are multiple choice questions (with answers) about philosophy. +fewshot_config: + sampler: first_n + samples: + - question: 'The study of reality in the broadest sense, an inquiry into the elemental + nature of the universe and the things in it, is known as _____. + + (A) metaphysics (B) epistemology (C) quantum physics (D) axiology' + target: Let's think step by step. We refer to Wikipedia articles on philosophy + for help. Among the options, only metaphysics studies the nature of reality + and existence. The answer is (A). + - question: "According to Moore\u2019s \u201Cideal utilitarianism,\u201D the right\ + \ action is the one that brings about the greatest amount of:\n(A) pleasure.\ + \ (B) happiness. (C) good. (D) virtue." + target: Let's think step by step. We refer to Wikipedia articles on philosophy + for help. Moore's "ideal utilitarianism" states that one's actions should maximize + intrinsic goods. The answer is (C). + - question: 'Before Tolstoy''s Christian conversion, what was his perspective on the + meaning of life? + + (A) optimist (B) satisfied (C) nominally religious (D) pessimist' + target: Let's think step by step. We refer to Wikipedia articles on philosophy + for help. Before his conversion, Tolstoy feels that life was uncertain, which + is a pessimist's point of view. The answer is (D). + - question: 'According to d''Holbach, people always act according to _____. + + (A) free choices (B) dictates of the soul (C) necessary natural laws (D) undetermined + will' + target: Let's think step by step. We refer to Wikipedia articles on philosophy + for help. d'Holbach believes that people act according to necessary laws, and + it proves nothing about people's free will. The answer is (C). + - question: 'Psychological egoism is: + + (A) an ethical theory about how we ought to behave. (B) a generalization concerning + the way people tend to behave. (C) a claim about human nature and the ways people + are capable of behaving. (D) none of the above.' + target: 'Let''s think step by step. We refer to Wikipedia articles on philosophy + for help. Psychological egoism suggests that one behaves based on what makes + one feels good, hence it is a claim about human nature and how humans are capable + of behaving. The answer is (C).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_philosophy diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6f9e5d81664445497a13a9adf7ca818ed6d2c7ef --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_prehistory.yaml @@ -0,0 +1,59 @@ +dataset_name: prehistory +description: The following are multiple choice questions (with answers) about prehistory. +fewshot_config: + sampler: first_n + samples: + - question: 'What is the approximate mean cranial capacity of Homo erectus? + + (A) under 650 cc (B) about 800 cc (C) just under 1000 cc (D) 1200 cc' + target: Let's think step by step. We refer to Wikipedia articles on prehistory + for help. The average cranium capacity of Homo erectus is less than 1000 cubic + cm. The answer is (C). + - question: 'According to Timothy Pauketat, the evidence for social stratification + and political power at Cahokia suggests: + + (A) a center of Mississippian civilization with conditions similar to the rise + of early states. (B) the limitations of authority in a Native American society + of egalitarian foragers. (C) a simple chiefdom or perhaps a complex chiefdom + had evolved by A.D. 1500. (D) a center of Mississippian civilization with conditions + similar to societies on the Northwest Coast of North America.' + target: Let's think step by step. We refer to Wikipedia articles on prehistory + for help. Timothy Pauketat is known for his research on Cahokia, the center + of the Mississippian culture, where he found similar conditions to the rise + of early states. The answer is (A). + - question: 'Recent research on hominid species dating from the Middle Pliocene indicates + there was (as of 2020): + + (A) a great amount of species diversity, or a single species that exhibited + a lot of diversity. (B) very little species diversity during this period and + very few hominids. (C) decreased species diversity due to a prolonged ice age + followed by a severe drought. (D) decreased species diversity but increased + numbers of hammerstones and flakes, indicating stone tool manufacture.' + target: Let's think step by step. We refer to Wikipedia articles on prehistory + for help. Recent research has recognized multiple hominid species from the Middle + Pliocene, meaning that there is a great amount of species diversity or diversity + in a single species. The answer is (A). + - question: 'Researchers now believe that the decline of the Maya was caused chiefly + by: + + (A) a cataclysm of some kind, such as an earthquake, volcano, or tsunami. (B) + ecological degradation resulting from slash-and-burn farming techniques. (C) + endless wars between neighboring Mayan city-states. (D) practices of interbreeding + that led to a steep rise in congenital disorders.' + target: Let's think step by step. We refer to Wikipedia articles on prehistory + for help. Researchers believe that the Maya collapse was mainly caused by over-exploitation + of natural resources like the slash-and-burn farming techniques. The answer + is (B). + - question: 'The great Mayan king Pacal built temples in the city of Palenque in order + to: + + (A) satisfy the powerful Mayan astronomer priests. (B) display his generosity + to the common people, since they were allowed to live in the temples. (C) frighten + away enemies, in particular the Spaniards. (D) legitimize his kingship, since + his father was not royal.' + target: 'Let''s think step by step. We refer to Wikipedia articles on prehistory + for help. Pacal built the temples as the funerary monument to legitimize his + kingship. The answer is (D).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_prehistory diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d18fc22626b952d18491b849cabc720706c17c9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_accounting.yaml @@ -0,0 +1,63 @@ +dataset_name: professional_accounting +description: The following are multiple choice questions (with answers) about professional + accounting. +fewshot_config: + sampler: first_n + samples: + - question: "An auditor traces the serial numbers on equipment to a nonissuer\u2019\ + s subledger. Which of the following management assertions is supported by this\ + \ test?\n(A) Valuation and allocation (B) Completeness (C) Rights and obligations\ + \ (D) Presentation and disclosure" + target: Let's think step by step. We refer to Wikipedia articles on accounting + for help. The completeness assertion is tested by tracing supporting documents + to the record entries. The answer is (B). + - question: 'One hundred years ago, your great-great-grandmother invested $100 at 5% + yearly interest. What is the investment worth today? + + (A) $13,000 (B) $600 (C) $15,000 (D) $28,000' + target: Let's think step by step. We refer to Wikipedia articles on accounting + for help. A $100 investment at 5% yearly interest is worth 100*(1.05)^100=13150 + after 100 years, which is around $13,000. The answer is (A). + - question: 'On January 1, year 1, Alpha Co. signed an annual maintenance agreement + with a software provider for $15,000 and the maintenance period begins on March + 1, year 1. Alpha also incurred $5,000 of costs on January 1, year 1, related + to software modification requests that will increase the functionality of the + software. Alpha depreciates and amortizes its computer and software assets over + five years using the straight-line method. What amount is the total expense + that Alpha should recognize related to the maintenance agreement and the software + modifications for the year ended December 31, year 1? + + (A) $5,000 (B) $13,500 (C) $16,000 (D) $20,000' + target: Let's think step by step. We refer to Wikipedia articles on accounting + for help. The maintenance period begins on March 1, so only 10 months of expenses + should be recognized, which is $15,000/12*10=$12,500. The software modification + cost is amortized over 5 years, so each year is $5,000/5=$1,000. So the total + expense is $12,500+$1,000=$13,500. The answer is (B). + - question: 'Krete is an unmarried taxpayer with income exclusively from wages. By + December 31, year 1, Krete''s employer has withheld $16,000 in federal income + taxes and Krete has made no estimated tax payments. On April 15, year 2, Krete + timely filed for an extension request to file her individual tax return, and + paid $300 of additional taxes. Krete''s year 1 tax liability was $16,500 when + she timely filed her return on April 30, year 2, and paid the remaining tax + liability balance. What amount would be subject to the penalty for underpayment + of estimated taxes? + + (A) $0 (B) $500 (C) $1,650 (D) $16,500' + target: Let's think step by step. We refer to Wikipedia articles on accounting + for help. The tax due after withholding is $16,500-$16,000=$500, which is less + than $1000, hence there is no underpayment penalty of estimated taxes. The answer + is (A). + - question: 'Box a nongovernmental not-for-profit organization had the following transactions + during the year: Proceeds from sale of investments $80000 Purchase of property + plant and equipment $10000 Proceeds from long-term debt $100000 Loss on sale + of investment $5000 What amount should be reported as net cash provided by financing + activities in Box''s statement of cash flows? + + (A) $70,000 (B) $75,000 (C) $80,000 (D) 100000' + target: 'Let''s think step by step. We refer to Wikipedia articles on accounting + for help. Among the four transactions, only Proceeds from long-term debt belongs + to the financing activities section of cashflow, hence the amount reported should + be $100000. The answer is (D).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_professional_accounting diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..307f8940bc445305fdbf00e89910cd5237a41312 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_law.yaml @@ -0,0 +1,122 @@ +dataset_name: professional_law +description: The following are multiple choice questions (with answers) about professional + law. +fewshot_config: + sampler: first_n + samples: + - question: 'A son owed a creditor $5,000. The son''s father contacted the creditor + and told him that he wanted to pay the son''s debt. The father signed a document + that stated the father would pay the son''s debt at a rate of $500 a month for + 10 months. The creditor made no written or oral commitment to forbear to sue + the son to collect the $5,000 debt, and the father made no oral or written request + for any such forbearance. For the next five months, the father made and the + creditor accepted the $500 monthly payments as agreed. During that period, the + creditor, in fact, did forbear to take any legal action against the son. However, + the father then informed the creditor that he would make no further payments + on the debt. Which of the following is the most persuasive argument that the + father is liable to the creditor under the terms of their agreement? + + (A) The father''s promise and the creditor''s reliance thereon, if proved, gave + rise to a valid claim by the creditor against the father based on the doctrine + of promissory estoppel. (B) Because it was foreseeable that the father''s promise + would induce the creditor to forbear taking any action against the son, such + forbearance was, as a matter of law, a bargained-for consideration for the father''s + promise. (C) The father''s five payments to the creditor totaling $2,500 manifested + a serious intent on the father''s part to be contractually bound, and such manifestation + is generally recognized as an effective substitute for consideration. (D) By + assuming the antecedent debt obligation that the son owed to the creditor, the + father became a surety whose promise to the creditor was enforceable, since + it was in writing and supported by adequate consideration. ' + target: Let's think step by step. We refer to Wikipedia articles on law for help. + The doctrine of promissory estoppel stops a person from going back on a promise + in contract law, hence option (A) should be the most persuasive argument. The + answer is (A). + - question: 'A state has recently enacted a statute prohibiting the disposal of any + nuclear wastes within the state. This law does not contravene or conflict with + any federal statutes. A man operates a company in the state that is engaged + in the disposal of nuclear wastes. Subsequent to the passage of the state statute, + the man, not yet aware of the new law, entered into contracts with many out-of-state + firms to dispose of their nuclear wastes in the state. On account of this new + law, however, the man will be unable to perform these contracts. Assume that + the man has standing to challenge this state law. Which of the following presents + his strongest constitutional grounds to challenge the state law prohibiting + the disposal of nuclear wastes within the state? + + (A) The commerce clause. (B) The equal protection clause of the Fourteenth Amendment. + (C) The privileges and immunities clause of Article IV, Section 2. (D) The contract + clause.' + target: Let's think step by step. We refer to Wikipedia articles on law for help. + The commerce clause states that Congress shall have the power to regulate commerce + with foreign Nations, and among the several States, and with the Indian Tribes. + The statute affects inter-state commerce which puts it into question. Hence + the man's strongest argument should be the commerce clause. The answer is (A). + - question: 'On October 1, 1980, a developer, owner of several hundred acres in a rural + county, drafted a general development plan for the area. The duly recorded plan + imposed elaborate limitations and restrictions upon the land in the plan, which + was to be developed as a residential district. The restrictions were to extend + to all persons acquiring any of the lots and to their heirs, assigns, and lessees. + It was further provided that all subsequent owners would be charged with due + notice of the restrictions. Among those restrictions in the general plan were + the following:(22) A franchise right is created in a strip of land 10 feet in + width along the rear of each lot for the use of public utility companies with + right of ingress and egress. (23) No house or structure of any kind shall be + built on the aforementioned strip of land running through the said blocks. In + 2000, a retiree purchased one of the lots, built a house, and erected a fence + in the rear of his property within the restricted area. In 2004, a teacher purchased + a lot adjacent to the retiree''s property and built a new house. Two years later, + a librarian purchased the lot that adjoined the teacher''s property. The three + deeds to those properties each contained references to the deed book where the + general plan was recorded. In 2008, the librarian began the construction of + a seven-foot post-and-rail fence along the line dividing his lot with the teacher''s, + and along the center of the area subject to the franchise right. Although the + teacher objected to its construction, the fence was completed. If the teacher + seeks a mandatory injunction to compel removal of the librarian''s fence, the + court will most likely + + (A) grant relief, because the fence was in violation of the easement restriction. + (B) grant relief, because the encroachment of the fence violated the restriction + in the original plan. (C) deny relief, because the teacher failed to enforce + the restriction against the retiree. (D) deny relief, because the fence would + not be construed as "a structure" within the terms of the restriction. ' + target: Let's think step by step. We refer to Wikipedia articles on law for help. + The restrictions in the original plan say no house or structure of any kind + shall be built on the aforementioned strip of land running through the said + blocks. Hence the court will most likely grant relief because the fence violated + the restriction in the original plan. The answer is (B). + - question: 'Judge took judicial notice of some facts at the beginning of the trial. + Which of the following is not an appropriate kind of fact for judicial notice? + + (A) Indisputable facts. (B) Facts that have been asserted by individual political + organizations. (C) Facts recognized to be true by common knowledge. (D) Facts + capable of scientific verification.' + target: Let's think step by step. We refer to Wikipedia articles on law for help. + Among the options, facts that have been asserted by individual political organizations + is not an appropriate kind of fact for judicial notice. The answer is (B). + - question: 'A state legislature has recently enacted a statute making it a misdemeanor + to curse or revile or use obscene or opprobrious language toward or in reference + to a police officer perfonning his duties. A student at a state university organized + a demonstration on campus to protest the war. The rally was attended by a group + of 50 students who shouted anti-war messages at cars passing by. To show his + contempt for the United States, the student sewed the American flag to the rear + of his jeans. When a police officer saw the flag sown on the student''s jeans, + he approached and told him to remove the flag or he would be placed under arrest. + The student became angered and shouted at the police officer, "Listen, you bastard, + I''ll wear this rag anywhere I please. " The student was subsequently placed + under arrest and charged with violating the state statute. The student subsequently + brings suit in state court challenging the constitutionality of the statute. + The strongest constitutional argument for the student is that + + (A) the statute is void for vagueness under the Fourteenth Amendment''s due + process clause. (B) the statute is invalid because it violates the petitioner''s + freedom of speech under the First Amendment. (C) the statute is an abridgment + of freedom of speech under the First Amendment because less restrictive means + are available for achieving the same purpose. (D) the statute is overbroad and + consequently invalid under the First and Fourteenth Amendments.' + target: 'Let''s think step by step. We refer to Wikipedia articles on law for + help. The Fourteenth Amendment further supports the First Amendment by establishing + a due process clause. Hence the strongest argument should be the statute is + overbroad and consequently invalid under the First and Fourteenth Amendments. + The answer is (D).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_professional_law diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4a5b0d995ae7d0897eaea3a6cfd0719562cdc3c9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_medicine.yaml @@ -0,0 +1,82 @@ +dataset_name: professional_medicine +description: The following are multiple choice questions (with answers) about professional + medicine. +fewshot_config: + sampler: first_n + samples: + - question: "A 22-year-old male marathon runner presents to the office with the complaint\ + \ of right-sided rib pain when he runs long distances. Physical examination\ + \ reveals normal heart and lung findings and an exhalation dysfunction at ribs\_\ + 4-5 on the right. Which of the following muscles or muscle groups will be most\ + \ useful in correcting this dysfunction utilizing a direct method?\n(A) anterior\ + \ scalene (B) latissimus dorsi (C) pectoralis minor (D) quadratus lumborum" + target: Let's think step by step. We refer to Wikipedia articles on medicine for + help. Among the options, only pectoralis minor muscle origins from the outer + surfaces of the 3rd to 5th ribs. The answer is (C). + - question: "A 36-year-old male presents to the office with a\_3-week\_history of low\ + \ back pain. He denies any recent trauma but says that he climbs in and out\ + \ of his truck numerous times a day for his job. Examination of the patient\ + \ in the prone position reveals a deep sacral sulcus on the left, a posterior\ + \ inferior lateral angle on the right, and a lumbosacral junction that springs\ + \ freely on compression. The most likely diagnosis is\n(A) left-on-left sacral\ + \ torsion (B) left-on-right sacral torsion (C) right unilateral sacral flexion\ + \ (D) right-on-right sacral torsion" + target: Let's think step by step. We refer to Wikipedia articles on medicine for + help. The deep sulcus on the left, a posterior ILA on the right, with a negative + spring test suggests a right-on-right sacral torsion. All other options have + a deep sulcus on the right. The answer is (D). + - question: "A 44-year-old man comes to the office because of a 3-day history of sore\ + \ throat, nonproductive cough, runny nose, and frontal headache. He says the\ + \ headache is worse in the morning and ibuprofen does provide some relief. He\ + \ has not had shortness of breath. Medical history is unremarkable. He takes\ + \ no medications other than the ibuprofen for pain. Vital signs are temperature\ + \ 37.4\xB0C (99.4\xB0F), pulse 88/min, respirations 18/min, and blood pressure\ + \ 120/84 mm Hg. Examination of the nares shows erythematous mucous membranes.\ + \ Examination of the throat shows erythema and follicular lymphoid hyperplasia\ + \ on the posterior oropharynx. There is no palpable cervical adenopathy. Lungs\ + \ are clear to auscultation. Which of the following is the most likely cause\ + \ of this patient's symptoms?\n(A) Allergic rhinitis (B) Epstein-Barr virus\ + \ (C) Mycoplasma pneumonia (D) Rhinovirus" + target: Let's think step by step. We refer to Wikipedia articles on medicine for + help. The symptoms, especially the headache, suggest that the most likely cause + is Rhinovirus. Epstein-Barr virus will cause swollen lymph nodes but there is + no palpable cervical adenopathy. Lungs are clear to auscultation suggests it's + not Mycoplasma pneumonia. The answer is (D). + - question: 'A previously healthy 32-year-old woman comes to the physician 8 months + after her husband was killed in a car crash. Since that time, she has had a + decreased appetite and difficulty falling asleep. She states that she is often + sad and cries frequently. She has been rechecking the door lock five times before + leaving her house and has to count exactly five pieces of toilet paper before + she uses it. She says that she has always been a perfectionist but these urges + and rituals are new. Pharmacotherapy should be targeted to which of the following + neurotransmitters? + + (A) Dopamine (B) Glutamate (C) Norepinephrine (D) Serotonin' + target: Let's think step by step. We refer to Wikipedia articles on medicine for + help. The patient feels sad and among the options, only Dopamine and Serotonin + can help increase positive emotions. Serotonin also affects digestion and metabolism, + which can help the patient's decreased appetite and sleep difficulty. The answer + is (D). + - question: "A 42-year-old man comes to the office for preoperative evaluation prior\ + \ to undergoing adrenalectomy scheduled in 2 weeks. One month ago, he received\ + \ care in the emergency department for pain over his right flank following a\ + \ motor vehicle collision. At that time, blood pressure was 160/100 mm Hg and\ + \ CT scan of the abdomen showed an incidental 10-cm left adrenal mass. Results\ + \ of laboratory studies, including complete blood count, serum electrolyte concentrations,\ + \ and liver function tests, were within the reference ranges. The patient otherwise\ + \ had been healthy and had never been told that he had elevated blood pressure.\ + \ He takes no medications. A follow-up visit in the office 2 weeks ago disclosed\ + \ elevated urinary normetanephrine and metanephrine and plasma aldosterone concentrations.\ + \ The patient was referred to a surgeon, who recommended the adrenalectomy.\ + \ Today, vital signs are temperature 36.6\xB0C (97.9\xB0F), pulse 100/min, respirations\ + \ 14/min, and blood pressure 170/95 mm Hg. Physical examination discloses no\ + \ significant findings. Initial preoperative preparation should include treatment\ + \ with which of the following?\n(A) Labetalol (B) A loading dose of potassium\ + \ chloride (C) Nifedipine (D) Phenoxybenzamine" + target: 'Let''s think step by step. We refer to Wikipedia articles on medicine + for help. The symptoms and the adrenal mass suggested pheochromocytoma, and + the blood pressure indicates hypertension. Phenoxybenzamine is used to treat + hypertension caused by pheochromocytoma. The answer is (D).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_professional_medicine diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60b5da683ff87d207304b894a5138a6c439a1c86 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_professional_psychology.yaml @@ -0,0 +1,62 @@ +dataset_name: professional_psychology +description: The following are multiple choice questions (with answers) about professional + psychology. +fewshot_config: + sampler: first_n + samples: + - question: 'In the construction of a multiple regression equation for purposes of + prediction, the optimal combination of measures is one in which the predictors + + (A) are uncorrelated with each other but are moderately correlated with the + criterion (B) have low correlations with each other and low correlations with + the criterion (C) are highly intercorrelated with each other and moderately + correlated with the criterion (D) have low correlations with the criterion bur + are moderately correlated with each other' + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. The basis of multiple regression is to assess the relationship between + one continuous variable and a set of independent variables. So the predictors + should be uncorrelated with each other but are moderately correlated with the + criterion. The answer is (A). + - question: 'There are three ways to measure the Central Tendency: the Mean, the Median + and the Mode. From your knowledge about them, what is the mode? + + (A) less sensitive to extreme scores than the mean (B) more useful for skewed + distributions (C) sensitive to extreme values and highly skewed distributions + (D) the most frequently occurring number' + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. The definition of mode is the most frequently occurring number. The + answer is (D). + - question: "Carl Jung believed that a client's transference:\n(A) is a fantasy that\ + \ distracts the client from reality. (B) represents \u201Cmixed feelings\u201D\ + \ toward the therapist. (C) \"is a form of \"\"acting out.\"\"\" (D) reflects\ + \ the client\u2019s personal and collective unconscious." + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. Transference is a phenomenon that a person's feelings are unconsciously + redirected, so it reflects the client's personal and collective unconscious. + The answer is (D). + - question: "In terms of Hofstede\u2019s (1980) five cultural dimensions, the United\ + \ States scores at the top of the scale on:\n(A) individualism. (B) individualism\ + \ and power distance. (C) power distance and masculinity. (D) uncertainty avoidance." + target: Let's think step by step. We refer to Wikipedia articles on psychology + for help. US scores highest on individualism among the five cultural dimensions. + The answer is (A). + - question: 'One of your therapy clients asks your advice about a good weight- reduction + program. You have investigated the programs in the community and are enrolled + in the one you consider the best. This program offers a $50 bonus to its patrons + for each new person they bring into the program. Under these circumstances, + your most appropriate response would be to + + (A) tell your client the pros and cons of each program you know about except + for the one in which you are enrolled (B) recommend to your client the program + in which you are enrolled and explain the $50 bonus you will receive (C) recommend + to your client the program in which you are enrolled and offer to have the $50 + bonus credited to your client''s account in the program (D) tell your client + the pros and cons of each program you know about, but do not claim the $50 bonus + if your client enrolls in your program' + target: 'Let''s think step by step. We refer to Wikipedia articles on psychology + for help. Based on the circumstances, you should tell your client about the + pros and cons of each program, but it would be inappropriate to receive the + bonus, so you should not claim the $50 bonus. The answer is (D).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_professional_psychology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe384b1e2b7d19c216f8344d5c249f2c16dc723b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_public_relations.yaml @@ -0,0 +1,55 @@ +dataset_name: public_relations +description: The following are multiple choice questions (with answers) about public + relations. +fewshot_config: + sampler: first_n + samples: + - question: 'Earth Hour was a campaign launched by which organization? + + (A) Greenpeace (B) The UN (C) Oxfam (D) World Wildlife Fund' + target: Let's think step by step. We refer to Wikipedia articles on public relations + for help. Earth Hour is a worldwide movement oragnized launched by the World + Wildlife Fund. The answer is (D). + - question: 'In issues management, what is the most proactive approach to addressing + negative or misleading information posted online about your organization? + + (A) Buy domain names that could be used by opposition groups. (B) Post anonymous + comments on blogs to combat this information. (C) Prepare a news release that + discredits the inaccurate information. (D) Make policy changes to address complaints + highlighted on these sites.' + target: Let's think step by step. We refer to Wikipedia articles on public relations + for help. In issues management, the most proactive approach to addressing negative + or misleading information posted online is to make policy changes to address + complaints highlighted on those sites. The answer is (D). + - question: 'At which stage in the planning process would a situation analysis be carried + out? + + (A) Defining the program (B) Planning the program (C) Taking action and implementing + ideas (D) Evaluation of the program' + target: Let's think step by step. We refer to Wikipedia articles on public relations + for help. Situation analyses are typically carried out during the planning process + stage of defining the program. The answer is (A). + - question: 'Which of these statements is true of the Vatican in 2010 at the time of + the accusations of child abuse cover-ups? + + (A) There was a coordinated media response. (B) Consistent messages were communicated. + (C) Criticisms were taken as attacks on the Catholic Church. (D) The credibility + of the Vatican was upheld.' + target: Let's think step by step. We refer to Wikipedia articles on public relations + for help. In 2010 when there were accusations of child abuse cover-ups, the + Vatican took those criticisms as attacks on the Catholic Church. The answer + is (C). + - question: 'What should a public relations media practitioner do if she does not know + the answer to a reporter''s question? + + (A) Give the reporter other information she is certain is correct. (B) Say that + the information is ''off the record'' and will be disseminated later. (C) Say + ''I don''t know'' and promise to provide the information later. (D) Say ''no + comment,'' rather than appear uninformed.' + target: 'Let''s think step by step. We refer to Wikipedia articles on public relations + for help. If a public relations media practitioner does not know the answer + to a reporter''s question, they should say ''I don''t know'' and offer to provide + the information later. The answer is (C).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_public_relations diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b37e35b3bd4fefa0ca040f0d59ff2fcae156fb45 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_security_studies.yaml @@ -0,0 +1,104 @@ +dataset_name: security_studies +description: The following are multiple choice questions (with answers) about security + studies. +fewshot_config: + sampler: first_n + samples: + - question: 'What are the frameworks of analysis within which terrorism has been considered + (as of 2020)? + + (A) Competition between larger nations has resulted in some countries actively + supporting terrorist groups to undermine the strength of rival states. Terrorist + networks are extended patronage clubs maintained and paid for by their donor + states and are conceptualised as being like state actors, to be dealt with using + military force. (B) Globalization has enabled the internationalization of terrorist + activities by opening up their operational space, although coordination is still + managed from a geographical base. This suggests that terrorist groups are nationally + structured which means that terrorism cannot be considered in terms of a war + to be defeated militarily without having serious implications on the indigenous + population. (C) Terrorism can be viewed as a problem to be resolved by military + means (war on terrorism), by normal police techniques (terrorism as crime), + or as a medical problem with underlying causes and symptoms (terrorism as disease). + (D) Terrorism is viewed as a criminal problem. The criminalization of terrorism + has two important implications. Firstly, it suggests that terrorism can be eradicated + - terrorists can be caught and brought to trial by normal judicial proceedings + thereby removing the threat from society - and secondly, it suggests that preventative + crime techniques are applicable to prevent its development.' + target: "Let's think step by step. We refer to Wikipedia articles on security\ + \ studies for help. (A) is wrong because it is not competition between larger\ + \ nations that causes terrorism. \n(B) is wrong because globalization is not\ + \ the cause of terrorism.\n(C) is correct because the US undertook the war on\ + \ terrorism. \n(D) is wrong because preventative crime techniques will likely\ + \ not end terrorism. The answer is (C)." + - question: 'Which of the following is the best lens through which to investigate the + role of child soldiers? + + (A) Child soldiers are victims of combat that need re-education and rehabilitation. + (B) Children and their mothers are not active subjects in warfare and are best + considered as subjects in the private sphere. (C) Children are most often innocent + bystanders in war and are best used as signifiers of peace. (D) Children have + political subjecthood that is missed when they are considered as passive victims + of warfare.' + target: Let's think step by step. We refer to Wikipedia articles on security studies + for help. Child soliders as a political topic can be missed when they are considered + passive victims of warfare. The answer is (D). + - question: 'How can we best describe the relationship between the state-centric approach + and the concept of human security? + + (A) There are such wide divisions within the human security framework regarding + the nature of threats and referent objects that no widely applicable comparisons + between state-centric approaches and human security can be drawn. (B) By adopting + the framework of human security, the limitations of the realist state-centric + approach become evident. Whilst human security defines the referent object as + the person or population, state-centric approaches prioritise the security of + the state, de-prioritizing the pursuit of human security. (C) The state-centric + approach to security is a faction of human security, usually defined within + the broad school of human security. By being state-centric this approach prioritises + the individual as the referent object in security studies. (D) Both the state-centric + and human-centric approaches to security are mutually exclusive and offer a + sufficient analytic framework with which to understand the international security + system. It is therefore the role of security analysts to determine which of + these substantial concepts is correct, and which should be discarded.' + target: Let's think step by step. We refer to Wikipedia articles on security studies + for help. Human security focuses on a person or population whereas state-centric + approaches focus on the state while deprioritizing human security. The answer + is (B). + - question: 'In order to become securitized, a threat must be presented in which of + these ways? + + (A) As an existential threat that requires immediate and extraordinary action, + posing a threat to the survival of the state or to societal security. (B) As + requiring immediate and extraordinary action by the state, threatening the survival + of a referent object and therefore warranting the use of measures not normally + employed in the political realm. (C) As an urgent threat to the survival of + the referent object, so serious that it legitimises the employment of extraordinary + action in response. (D) As an urgent threat to the survival of the audience + that requires extraordinary or emergency measures.' + target: Let's think step by step. We refer to Wikipedia articles on security studies + for help. To be securitized, a threat must be an urgent threat to the survival + of the referent object. The answer is (C). + - question: 'What distinguishes coercive diplomacy from military force? + + (A) Compellence is another term for coercive diplomacy, but covering a narrower + set of criteria; compellence covers those threats aimed at initiating adversary + action. A threat to coerce a state to give up part of its territory would count + as coercive diplomacy, as long as that threat proactively initiates action before + reactive diplomacy is taken. (B) Coercive diplomacy constitutes the threats + of limited force to induce adversary''s incentive to comply with the coercer''s + demands. It is an influence strategy that is intended to obtain compliance: + the use of force to defeat an opponent first does not count. It leaves an element + of choice with the target to comply, or to continue. (C) Military force, or + the threat of military force, utilises fear to achieve strategic objectives. + Coercive diplomacy is differentiated from this approach, because it does not + use fear as a tool for coercing an adversary. (D) Coercive diplomacy is employed + to use force but to limit its effects on the international community. Coercive + diplomacy is an aggressive strategy that is intended to obtain compliance through + defeat. It does not leave an element of choice with the target, the target either + being forced to comply or engage in conflict. It seeks to control by imposing + compliance by removing any opportunity for negotiation or concession.' + target: 'Let''s think step by step. We refer to Wikipedia articles on security + studies for help. Coercive diplomacy uses the threat of force to induce the + opponent to comply with demands. The answer is (B).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_security_studies diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4229d64785ded8673d421a9fb1571d0cce705a93 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_sociology.yaml @@ -0,0 +1,58 @@ +dataset_name: sociology +description: The following are multiple choice questions (with answers) about sociology. +fewshot_config: + sampler: first_n + samples: + - question: 'Which of the following is not a problem associated with official statistics + on strike action? + + (A) most strikes go unnoticed by employers and the mass media (B) not all industrial + disputes will be reported by the employer (C) the definition of strikes excludes + those that involve fewer than ten workers or last less than one day (D) it is + hard to compare strikes that were measured in different ways' + target: Let's think step by step. We refer to Wikipedia articles on sociology + for help. Official statistics on strike action can be problematic because not + all industrial disputes will be reported by employers, the definition of strikes + excludes those that involves fewer than ten workers or last less than one day, + and it is hard to compare strikes that were measured in different ways. Thus, + (A) is not a problem associated with official statistics on strike action. The + answer is (A). + - question: 'What does Berger (1963) describe as a metaphor for social reality? + + (A) a fairground ride (B) a circus (C) a puppet theatre (D) a ballet' + target: Let's think step by step. We refer to Wikipedia articles on sociology + for help. Berger describes social reality using the metaphor of a puppet theatre. + The answer is (C). + - question: 'The term ''hegemony'' refers to: + + (A) the tendency for the working class not to realize their own interests (B) + a dominant ideology that legitimates economic, political and cultural power + (C) a form of dual consciousness based on ideology and everyday experiences + (D) a mode of payment given for outstanding topiary' + target: Let's think step by step. We refer to Wikipedia articles on sociology + for help. Hegemony refers to a dominant ideology that legitimates economic, + policital, and cultural power. The answer is (B). + - question: 'The shift from ''civil religion'' to ''common religion'' means that: + + (A) the increasing bureaucracy of the state has made religion only a marginal + part of our lives (B) despite the weakening of traditional authority, our everyday + lives and ''common sense'' remain shaped by religious beliefs and values (C) + religious participation in collective worship may have declined, but people + still practise their faiths in private (D) people are much more likely to discuss + their religious beliefs in public, informal settings' + target: Let's think step by step. We refer to Wikipedia articles on sociology + for help. The shift from civil religion to common religion means that despite + the weakening of traditional authority, our everyday lives and common sense + remain shaped by religious beliefs and values. The answer is (B). + - question: 'Which of the following did the post-war welfare state of 1948 not aim + to provide: + + (A) free health care and education for all (B) a minimum wage (C) full employment + (D) universal welfare' + target: 'Let''s think step by step. We refer to Wikipedia articles on sociology + for help. The post-war welfare state of 1948 aimed to provide free healthcare + and education, full employment, and universal welfare. But it did not aim to + provide a minimum wage. The answer is (B).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_sociology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc86b7c88fa1b62d2f12deaa16394f43fc722225 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_us_foreign_policy.yaml @@ -0,0 +1,56 @@ +dataset_name: us_foreign_policy +description: The following are multiple choice questions (with answers) about us foreign + policy. +fewshot_config: + sampler: first_n + samples: + - question: 'How did Donald Trump attack globalization in the 2016 campaign? + + (A) Globalization had made men like him too rich (B) Globalization only benefited + certain American states, such as New York (C) Liberal elites had encouraged + globalization, while ''ordinary Americans'' lost jobs because of it (D) Globalization + encouraged damaging trade wars' + target: Let's think step by step. We refer to Wikipedia articles on us foreign + policy for help. Trump attacked globalization because he believed ordinary Americans + lost jobs due to it, and so he wanted to blame liberals who had encouraged it. + The answer is (C). + - question: 'How did NSC-68 change U.S. strategy? + + (A) It globalized containment. (B) It militarized containment. (C) It called + for the development of the hydrogen bomb. (D) All of the above' + target: Let's think step by step. We refer to Wikipedia articles on us foreign + policy for help. NSC-68 outlined a variety of courses of action, including globalization + of containment, militarization of contaiment, and the development of the hydrogen + bomb. The answer is (D). + - question: 'How do Defensive Realism and Offensive Realism differ in their explanation + of state behaviour? + + (A) Defensive realists place greater emphasis on the role of international institutions + (B) Defensive realists place less emphasis on geographical factors (C) Offensive + realists give more priority to the national interest than Defensive realists. + (D) Defensive realists believe states are security maximizers, while Offensive + realists believe states to be power maximizers' + target: Let's think step by step. We refer to Wikipedia articles on us foreign + policy for help. While defensive realism advocates that states are security + maximizers, offensive realists think of states as power maximizers. The answer + is (D). + - question: 'The realm of policy decisions concerned primarily with relations between + the United States and the rest of the world is known as + + (A) terrorism policy. (B) economic policy. (C) foreign policy. (D) international + policy.' + target: Let's think step by step. We refer to Wikipedia articles on us foreign + policy for help. The topic of policy decisions concerns with relations between + the US and the rest of the world is known as foreign policy. The answer is (C). + - question: 'How did the 2008 financial crisis affect America''s international reputation? + + (A) It damaged support for the US model of political economy and capitalism + (B) It created anger at the United States for exaggerating the crisis (C) It + increased support for American global leadership under President Obama (D) It + reduced global use of the US dollar' + target: 'Let''s think step by step. We refer to Wikipedia articles on us foreign + policy for help. The 2008 financial crisis damanged the international reputation + of the American model of political economy and capitalism. The answer is (A).' +tag: mmlu_flan_cot_fewshot_social_sciences +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_us_foreign_policy diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e70f0ee24cc946a1cfbc51cb87d4fde20d8171c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_virology.yaml @@ -0,0 +1,45 @@ +dataset_name: virology +description: The following are multiple choice questions (with answers) about virology. +fewshot_config: + sampler: first_n + samples: + - question: 'The median survival time to AIDS and death was established by following: + + (A) Seroprevalent HIV-infected individuals (B) Seronegatives (C) Seroconverters + (D) High-risk seronegatives' + target: Let's think step by step. We refer to Wikipedia articles on virology for + help. The median survival time to AIDS and death was established as a result + of the development of seroconverters. The answer is (C). + - question: 'Which of the following is a morphological characteristic of the paramyxoviruses. + + (A) Fragile viruses often visualised with RNA spewing from the inside (B) Elongate + viruses (C) Icosahedral viruses with envelope (D) Very large viruses' + target: Let's think step by step. We refer to Wikipedia articles on virology for + help. Paramyxoviruses are fragile viruses often visualised with RNA spewing + from the inside. The answer is (A). + - question: 'The most important goal of a behavioral intervention is: + + (A) Change in behavior (B) Comprehensive coverage (C) Effective use of behavioral + theory (D) Sustained behavior change' + target: Let's think step by step. We refer to Wikipedia articles on virology for + help. The prim goal of a behavioral intervention is to cause sustained behavior + change. The answer is (D). + - question: 'A key factor facilitating the application of nested case-control studies + from the MACS was: + + (A) Data collection (B) Establishment of a repository of biologic specimens + (C) Participant interest (D) Administration of the questionnaire by staff' + target: Let's think step by step. We refer to Wikipedia articles on virology for + help. The Multicenter AIDS Cohort Study's use of nested case-control studies + was facilitated by the establishment of a repository of biologic specimens. + The answer is (B). + - question: 'Why are parvoviruses a highly impactful parasite? + + (A) Because they have no nucleic acid (B) They require a helper virus (C) Only + replicate in dividing cells (D) Can integrate into host chromosomes' + target: 'Let''s think step by step. We refer to Wikipedia articles on virology + for help. Paroviruses are highly impactful because they do not have nucleic + acid. The answer is (A).' +tag: mmlu_flan_cot_fewshot_other +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_virology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41502cc7a3a1318b7c6a0f2ac16cda86dda08486 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_fewshot/mmlu_world_religions.yaml @@ -0,0 +1,42 @@ +dataset_name: world_religions +description: The following are multiple choice questions (with answers) about world + religions. +fewshot_config: + sampler: first_n + samples: + - question: 'How can the Upanishads be characterized? + + (A) Ritual texts (B) Philosophical texts (C) Hymns (D) Origin stories' + target: Let's think step by step. We refer to Wikipedia articles on world religions + for help. The Upanishads are the most recent part of Vedas (the oldest scriptures + in Hinduism) and supplied the basis of later Hindu philosophy. So they are philosophical + texts. The answer is (B). + - question: 'What is the Second Gem in Buddhism? + + (A) The Dharma (B) The Sangha (C) The Buddha (D) The Bodhisattva' + target: Let's think step by step. We refer to Wikipedia articles on world religions + for help. The Second Gem in Buddhism is The Dharma. The answer is (A). + - question: 'Which Japanese government promoted a kind of national cult based on the + emperor and his associations with kami? + + (A) Honen (B) Tanaka (C) Tokugawa (D) Meiji' + target: Let's think step by step. We refer to Wikipedia articles on world religions + for help. The promotion of a national cult based on the emperor and his associations + with Kami happened during the reign of Emperor Meiji (1852-1912). The answer + is (D). + - question: 'In which dynasty was the "Mandate of Heaven" developed to legitimatize + the new rulers? + + (A) Shang (B) Zhou (C) Han (D) Xia' + target: Let's think step by step. We refer to Wikipedia articles on world religions + for help. The "Mandate of Heaven" was developed as an ancient Chinese philosophical + concept during the Zhou Dynasty (1046-256 BCE). The answer is (B). + - question: 'What is the sign of the covenant for Jewish males? + + (A) The rainbow (B) Circumcision (C) A son (D) Bar mitzvah' + target: 'Let''s think step by step. We refer to Wikipedia articles on world religions + for help. In Judaism, the most distinctive sign of the covenant is circumcision + (brit milah). The answer is (B).' +tag: mmlu_flan_cot_fewshot_humanities +include: _mmlu_flan_cot_fewshot_template_yaml +task: mmlu_flan_cot_fewshot_world_religions diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..745a892568bd84b38252e20bbc9a0bea73ddb1db --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml @@ -0,0 +1,32 @@ +group: mmlu_flan_cot_zeroshot +group_alias: mmlu (flan style, zeroshot cot) +task: + - group: stem + task: + - mmlu_flan_cot_zeroshot_stem + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: other + task: + - mmlu_flan_cot_zeroshot_other + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: social sciences + task: + - mmlu_flan_cot_zeroshot_social_sciences + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: humanities + task: + - mmlu_flan_cot_zeroshot_humanities + aggregate_metric_list: + - metric: acc + weight_by_size: True +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu_flan_cot_zeroshot_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu_flan_cot_zeroshot_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..7588b67e1905dc7eae3790c33a16de460c465f67 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu_flan_cot_zeroshot_template_yaml @@ -0,0 +1,38 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +validation_split: validation +fewshot_split: dev +output_type: generate_until +doc_to_text: "Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA: Let's think step by step." +doc_to_target: "{{['(A)', '(B)', '(C)', '(D)'][answer]}}" +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))" + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "multi_choice_regex" + group_select: -1 + ignore_case: true + ignore_punctuation: true + regex_pattern: "(\\([A-Z]\\))" + - function: "take_first" +generation_kwargs: + until: + - "" + - "Q:" + - "<|im_end|>" + do_sample: false + temperature: 0.0 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 3.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e885b818eae4bbc87374c756b68ecd11e44bd69 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_abstract_algebra.yaml @@ -0,0 +1,6 @@ +"dataset_name": "abstract_algebra" +"description": "The following are multiple choice questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_abstract_algebra" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f17410a7cc0869223730328f55803d8d424e930 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_anatomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "anatomy" +"description": "The following are multiple choice questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_anatomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5b821f97642ad5987244a0ac4c9988c2fca3857 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_astronomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "astronomy" +"description": "The following are multiple choice questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_astronomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b26c679e26b6bd04d77eb5e0bb2ebaddcc515561 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_business_ethics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "business_ethics" +"description": "The following are multiple choice questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_business_ethics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c0e9d17db10f4e69d1c44d5a127f2bbe1f4e279 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_clinical_knowledge.yaml @@ -0,0 +1,6 @@ +"dataset_name": "clinical_knowledge" +"description": "The following are multiple choice questions (with answers) about clinical\ + \ knowledge.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_clinical_knowledge" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de020f4eaca7fdeb650688f034ee3b5d89490ddc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_biology" +"description": "The following are multiple choice questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8e5bbcf76b9fb3ad012511b213ffbbd554cd58d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_chemistry" +"description": "The following are multiple choice questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04b5e750949984abcd7889be80485e52c97dba9f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_computer_science" +"description": "The following are multiple choice questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81c59cc2c20f340a76ed3d945e976ce3c832815c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_mathematics" +"description": "The following are multiple choice questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0450a068f4b763629e463d9882e4a3e99f86d726 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_medicine" +"description": "The following are multiple choice questions (with answers) about college\ + \ medicine.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82c2bb2ab586be2346237a6aa8b2ea9fd9170c97 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_physics" +"description": "The following are multiple choice questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78216a44778fa0f9f1e057d5dc45b998fd5e87fc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_computer_security.yaml @@ -0,0 +1,6 @@ +"dataset_name": "computer_security" +"description": "The following are multiple choice questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_computer_security" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52304bdf8eeac624c63331b259255a98866dc2ac --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_conceptual_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "conceptual_physics" +"description": "The following are multiple choice questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_conceptual_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5be81c442710f91ad3e1ca6a0651105b2f14e24 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_econometrics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "econometrics" +"description": "The following are multiple choice questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_econometrics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..934a1a20a69d987904fe9c8b605c93e4ed149309 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_electrical_engineering.yaml @@ -0,0 +1,6 @@ +"dataset_name": "electrical_engineering" +"description": "The following are multiple choice questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_electrical_engineering" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96ec81d6a8716ad60a4b3215faa42f3c3b1396d7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_elementary_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are multiple choice questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_elementary_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..915c96de78b68bdd2b8b8cbb26f2f8ec0ae24167 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_formal_logic.yaml @@ -0,0 +1,6 @@ +"dataset_name": "formal_logic" +"description": "The following are multiple choice questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_formal_logic" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a1f7491590b80e784360ceb72619efe4d9568f1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_global_facts.yaml @@ -0,0 +1,6 @@ +"dataset_name": "global_facts" +"description": "The following are multiple choice questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_global_facts" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c4043d9bd7e6a38d702afa7ccb4028e98001445 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_biology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5aee89159d40e4f7c788cf670d9fa2e405d32c75 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are multiple choice questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb3eb2134bf8e3e8b8e81f29432db3e81b5f2fcf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are multiple choice questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6fc261e8fe114ffc9d7be99110d659704018f159 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_european_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_european_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_european_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..baabc83a9e25b700600fe516d9a84833c32f4f29 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_geography.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_geography" +"description": "The following are multiple choice questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_geography" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41365c509da451280527720e651d5793d1b83960 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05e62fa85cb3fdf871ec246de43d32c7a5209db1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c9a9ca3b3840ee7169b59a53cec4c595c783cd4e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2fb8639003555bdca712f3dc49ed6e463158be42 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_microeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c149ef083a87f6d3eb412f9e3fb2fbd131ec4c0e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_physics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school physics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..999f9be74e2bc278a068c344030ae27f3b2c3006 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_psychology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0f905569c82f31ec76a75505bfae64c28d72640 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_statistics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_statistics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school statistics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_statistics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d09cdcaa3b268d599e055f82c92779d4ecd2bcb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_us_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_us_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_us_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28a63b1b9106219486b5487b24396baf44179276 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_world_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_world_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_world_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a71bfc38aab72f17a01e3da11fc037ce28ef033 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_aging.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_aging" +"description": "The following are multiple choice questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_human_aging" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa9b895b7331b051385a31165c725c2ef976db69 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_sexuality.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_sexuality" +"description": "The following are multiple choice questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_human_sexuality" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33766a464fa475a012d229c194c93fffb84942b6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_international_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "international_law" +"description": "The following are multiple choice questions (with answers) about international\ + \ law.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_international_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..642e6ce4f34992cb5be8b840ea481c7a389d9ce8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_jurisprudence.yaml @@ -0,0 +1,6 @@ +"dataset_name": "jurisprudence" +"description": "The following are multiple choice questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_jurisprudence" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12594895469fbf0644e1908e4299f93f417703e8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_logical_fallacies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "logical_fallacies" +"description": "The following are multiple choice questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_logical_fallacies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c27feea94ce017e35bcd453d6cbf5c4db5b3334 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_machine_learning.yaml @@ -0,0 +1,6 @@ +"dataset_name": "machine_learning" +"description": "The following are multiple choice questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_machine_learning" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1a13763a2bd796821efa251071359ce0acbf1cf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_management.yaml @@ -0,0 +1,6 @@ +"dataset_name": "management" +"description": "The following are multiple choice questions (with answers) about management.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_management" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0fe6e44b7fe464396e85a53f70831bbb48ff8ece --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_marketing.yaml @@ -0,0 +1,6 @@ +"dataset_name": "marketing" +"description": "The following are multiple choice questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_marketing" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..813b6a3fe90413bd35a11f82624df600d8bf682b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_medical_genetics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "medical_genetics" +"description": "The following are multiple choice questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_medical_genetics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2a95e892a8e6d357e6a9f771272d06422b14d1a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_miscellaneous.yaml @@ -0,0 +1,6 @@ +"dataset_name": "miscellaneous" +"description": "The following are multiple choice questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_miscellaneous" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6a76a2a7930589f3603fa070e974116b4996e96 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_disputes.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_disputes" +"description": "The following are multiple choice questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_moral_disputes" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5f8c4e6f144dcb4c0eb6881b095434c76105bb6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_scenarios.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_scenarios" +"description": "The following are multiple choice questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_moral_scenarios" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0f144cb44e5218d3a70193fddca2a2883e6b1b8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_nutrition.yaml @@ -0,0 +1,6 @@ +"dataset_name": "nutrition" +"description": "The following are multiple choice questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_nutrition" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4e4c0c4b6ccd34ebf4ff1133d0e26ddd8dc90d9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_philosophy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "philosophy" +"description": "The following are multiple choice questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_philosophy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9db801a6a9f2d911e2bdbbe0084fd235c7572776 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_prehistory.yaml @@ -0,0 +1,6 @@ +"dataset_name": "prehistory" +"description": "The following are multiple choice questions (with answers) about prehistory.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_prehistory" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e94bef0581e5290ff4790b5d48863a198a904879 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_accounting.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_accounting" +"description": "The following are multiple choice questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_accounting" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25239d9a35941d49797c15986cc43213b0ec74d6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_law" +"description": "The following are multiple choice questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f961bff89745dd8999c2ee497bdf9a7df88e04f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_medicine" +"description": "The following are multiple choice questions (with answers) about professional\ + \ medicine.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48758ef76eaf72e4236a8569e041ea03e6626e67 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_psychology" +"description": "The following are multiple choice questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62a56a4478bf9eafbcf1a8034abfeea6240e99ca --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_public_relations.yaml @@ -0,0 +1,6 @@ +"dataset_name": "public_relations" +"description": "The following are multiple choice questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_public_relations" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..062f49630e82b66be1ea0e75ed9fe73c8d635215 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_security_studies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "security_studies" +"description": "The following are multiple choice questions (with answers) about security\ + \ studies.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_security_studies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36b4711831ef6fafde0915178e28513692f9c8d5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_sociology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "sociology" +"description": "The following are multiple choice questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_sociology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c4afb8f84a193442cd98a856ada7e43f1515cbce --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_us_foreign_policy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "us_foreign_policy" +"description": "The following are multiple choice questions (with answers) about us\ + \ foreign policy.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_us_foreign_policy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8e427612f45461a5d873edbafb3d6e0eba4e9f1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_virology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "virology" +"description": "The following are multiple choice questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_virology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0eb04f31f0baaf6ac0f358de2897d5267e1a4357 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_world_religions.yaml @@ -0,0 +1,6 @@ +"dataset_name": "world_religions" +"description": "The following are multiple choice questions (with answers) about world\ + \ religions.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_world_religions" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72246935de8cf0cf8b256fd1e6c87dfbbb90a2ad --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py @@ -0,0 +1,112 @@ +import re +import sys +import unicodedata + +from lm_eval.filters.extraction import RegexFilter + + +class MultiChoiceRegexFilter(RegexFilter): + """ """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, + ) -> None: + """ + regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure + - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. + - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. + group_select: Selects the (group_select)th match from the findall result. + ignore_case: Ignores the case during step 1 matching + ignore_punctuation: Remove the punctuation during step 1 matching + regexes_to_ignore: Remove these regexes during step 1 matching + """ + super().__init__(regex_pattern, group_select, fallback) + self.ignore_case = ignore_case + self.ignore_punctuation = ignore_punctuation + self.regexes_to_ignore = regexes_to_ignore + + def apply(self, resps, docs): + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + + def find_match(regex, resp, convert_dict={}): + match = regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m][0] + match = match.strip() + if match and match in convert_dict: + match = convert_dict[match] + return match + + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + + def filter_ignores(st): + if self.regexes_to_ignore is not None: + for s in self.regexes_to_ignore: + st = re.sub(s, "", st) + + if self.ignore_case: + st = st.lower() + + if self.ignore_punctuation: + # https://stackoverflow.com/a/266162 + st = st.translate(punct_tbl) + return st + + filtered_resps = [] + + for r, doc in zip(resps, docs): + fallback_regexes = [] + choice_to_alpha = {} + next_alpha = "A" + + without_paren_fallback_regexes = [] + without_paren_to_target = {} + + choices = doc["choices"] + for c in choices: + m = filter_ignores(c.strip()) + fallback_regexes.append(f"{re.escape(m)}") + choice_to_alpha[m] = f"({next_alpha})" + + without_paren_fallback_regexes.append(next_alpha) + without_paren_to_target[next_alpha] = f"({next_alpha})" + + next_alpha = chr(ord(next_alpha) + 1) + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) + + filtered = [] + for resp in r: + match = find_match(self.regex, resp) + if not match: + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) + if not match: + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) + if not match: + match = self.fallback + filtered.append(match) + filtered_resps.append(filtered) + + return filtered_resps diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14465ad6e5c5434974832399ea95903b59e4eaf5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/_mmlu.yaml @@ -0,0 +1,32 @@ +group: mmlu_flan_n_shot_generative +group_alias: mmlu (flan style, generative) +task: + - group: stem + task: + - mmlu_flan_n_shot_generative_stem + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: other + task: + - mmlu_flan_n_shot_generative_other + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: social sciences + task: + - mmlu_flan_n_shot_generative_social_sciences + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: humanities + task: + - mmlu_flan_n_shot_generative_humanities + aggregate_metric_list: + - metric: acc + weight_by_size: True +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/_mmlu_flan_generative_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/_mmlu_flan_generative_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..a38a06969e2649d2fc0cf8e2be3efc60d91b3076 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/_mmlu_flan_generative_template_yaml @@ -0,0 +1,34 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +test_split: test +fewshot_split: dev +fewshot_config: + sampler: first_n +output_type: generate_until +doc_to_text: "Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA:" +doc_to_target: "{{['(A)', '(B)', '(C)', '(D)'][answer]}}" +filter_list: + - name: "strict-match" + filter: + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "multi_choice_regex" + group_select: 0 + regex_pattern: "(\\([A-Z]\\))" + ignore_case: true + ignore_punctuation: true + - function: "take_first" +generation_kwargs: + until: + - "" + - "Q:" + - "<|im_end|>" + - "\n" +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +metadata: + version: 3.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e602ee8100ed612d89385532ea30004c3033c35 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_abstract_algebra.yaml @@ -0,0 +1,6 @@ +"dataset_name": "abstract_algebra" +"description": "The following are multiple choice questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_abstract_algebra" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa12cc8ef35b19f3b81dcc58a0107d424a3580cc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_anatomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "anatomy" +"description": "The following are multiple choice questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_anatomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4178654e0e6e7a053839319c7936967133cf756 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_astronomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "astronomy" +"description": "The following are multiple choice questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_astronomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4447d276b066ddec93b8f7efcf2d74d13810f458 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_business_ethics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "business_ethics" +"description": "The following are multiple choice questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_business_ethics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38f799060fa6901b890d3a87d8aa9b9444d34b57 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_clinical_knowledge.yaml @@ -0,0 +1,6 @@ +"dataset_name": "clinical_knowledge" +"description": "The following are multiple choice questions (with answers) about clinical\ + \ knowledge.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_clinical_knowledge" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f36eb1f598f754154c2b15b24bbb650358c707c5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_biology" +"description": "The following are multiple choice questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_college_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0eccce652fade13a319af78e06a7528b11814302 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_chemistry" +"description": "The following are multiple choice questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_college_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd415aa10efaf96331d9fef82c5b6a2bb538263a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_computer_science" +"description": "The following are multiple choice questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_college_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d062721102c0f6e6c09574398a60db74c26b593 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_mathematics" +"description": "The following are multiple choice questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_college_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..edc660d9c30dfad6666f5e1b4c679489f62c5991 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_medicine" +"description": "The following are multiple choice questions (with answers) about college\ + \ medicine.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_college_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aac8f400d1d9005376bfe3354753e87700a7bda8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_college_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_physics" +"description": "The following are multiple choice questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_college_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..178c468346a5022a5d0031fd27c6b9a07ab24150 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_computer_security.yaml @@ -0,0 +1,6 @@ +"dataset_name": "computer_security" +"description": "The following are multiple choice questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_computer_security" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3cfbe6250d19aab6e60c9089f0feb91eed37423 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_conceptual_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "conceptual_physics" +"description": "The following are multiple choice questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_conceptual_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad8704e4f8e3a60ee2ff7e370cf7394c0359aeb7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_econometrics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "econometrics" +"description": "The following are multiple choice questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_econometrics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56eeae0183ca0c087b0a16aa317f2b93d5f1b87b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_electrical_engineering.yaml @@ -0,0 +1,6 @@ +"dataset_name": "electrical_engineering" +"description": "The following are multiple choice questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_electrical_engineering" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..da3b3af2b5f310232cbd9c9ee63081acbb571638 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_elementary_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are multiple choice questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_elementary_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d3f4edc644842cbc3fae865c96f99322daaafbf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_formal_logic.yaml @@ -0,0 +1,6 @@ +"dataset_name": "formal_logic" +"description": "The following are multiple choice questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_formal_logic" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4feef1895254438bde19ebfc3d7a36aee87e61de --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_global_facts.yaml @@ -0,0 +1,6 @@ +"dataset_name": "global_facts" +"description": "The following are multiple choice questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_global_facts" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..245d9be815c3644bf3298a0d093a76410b7487b6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_biology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34eb30d32d5b6927d44d59a63f5a549587f414f1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are multiple choice questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34250a6c61cb5e29acbb99f8a080d45f74a91d45 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are multiple choice questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42b7dd4d5aa2ab541b7f269c84845d262db452c5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_european_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_european_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_european_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e67277aa5480e1a9465169112755c3da70e12e6e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_geography.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_geography" +"description": "The following are multiple choice questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_geography" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84643a74239db620816f0d8a67575d0c8268e58f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb08333804237ac3e0584db637d5c91477a6a93d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1ca028d8262f22807eb591c3e498fecabd9887b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c60982b78dab4866a6827fe5b1bf9f2b710ed8d3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_microeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33b8d16739c9faf352ad242bd76b2bc33bc21aa6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_physics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school physics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f47bbbb68c02a417e60e5b0a19f4f85c5723b41b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_psychology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..741971895ba27ad6651ac456def204a078ac5d3e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_statistics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_statistics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school statistics.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_statistics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48696971c9e850a18baadd6c3e9f958851cc2a3e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_us_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_us_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_us_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ae6cfcbba3f86dc0339edc3a361c898e6c8716fd --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_high_school_world_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_world_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_high_school_world_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..677f119a754f0c671fae0f2285bb8ff29f2af85e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_human_aging.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_aging" +"description": "The following are multiple choice questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_human_aging" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d4e33d7d607ef2f07ea0fdb67305b8f88a45d13a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_human_sexuality.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_sexuality" +"description": "The following are multiple choice questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_human_sexuality" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac5d9d5a46b7f4f1daafb7c7f0feb66933c4829d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_international_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "international_law" +"description": "The following are multiple choice questions (with answers) about international\ + \ law.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_international_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2f135869aca516492cd9dc8ce210838173a1d7a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_jurisprudence.yaml @@ -0,0 +1,6 @@ +"dataset_name": "jurisprudence" +"description": "The following are multiple choice questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_jurisprudence" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6624e07743a432cc354ccff7af2363db2ec1ae11 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_logical_fallacies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "logical_fallacies" +"description": "The following are multiple choice questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_logical_fallacies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ab6c459ae50e7311dc9d8819ec753c69f6d9583b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_machine_learning.yaml @@ -0,0 +1,6 @@ +"dataset_name": "machine_learning" +"description": "The following are multiple choice questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_flan_n_shot_generative_stem" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_machine_learning" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4af9ded012e921feeb38d31cde98fef9888aba95 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_management.yaml @@ -0,0 +1,6 @@ +"dataset_name": "management" +"description": "The following are multiple choice questions (with answers) about management.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_management" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22ef9d3fd49556afd4578685099abc0bb9b64c9e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_marketing.yaml @@ -0,0 +1,6 @@ +"dataset_name": "marketing" +"description": "The following are multiple choice questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_marketing" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c24da7938b431acdd991830424777e6645cf9bbb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_medical_genetics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "medical_genetics" +"description": "The following are multiple choice questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_medical_genetics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5b90845321c954cc2e7875fdc084e5935444af7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_miscellaneous.yaml @@ -0,0 +1,6 @@ +"dataset_name": "miscellaneous" +"description": "The following are multiple choice questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_miscellaneous" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..295c39a6efce509983b01b18c20375866b08d3bc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_moral_disputes.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_disputes" +"description": "The following are multiple choice questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_moral_disputes" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f09f982f26462304a20420e9b61bf3ef941448a0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_moral_scenarios.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_scenarios" +"description": "The following are multiple choice questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_moral_scenarios" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf633f270a6d9fbbaa0a793bc5d5e48731a31d57 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_nutrition.yaml @@ -0,0 +1,6 @@ +"dataset_name": "nutrition" +"description": "The following are multiple choice questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_nutrition" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a5fe27eefb47badf4c13e87ad0fbac96b08283e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_philosophy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "philosophy" +"description": "The following are multiple choice questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_philosophy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60788fc6c201bf316398f48adc9575dcb806b649 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_prehistory.yaml @@ -0,0 +1,6 @@ +"dataset_name": "prehistory" +"description": "The following are multiple choice questions (with answers) about prehistory.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_prehistory" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f312af231f28d9343f7a0e2353cec110fda1f9a4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_accounting.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_accounting" +"description": "The following are multiple choice questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_professional_accounting" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be0533f0d8b90fc9f82226579ec849ac3f24be15 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_law" +"description": "The following are multiple choice questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_professional_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9cae6f8a5ec27d73bcf9b57e8597b377aee62835 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_medicine" +"description": "The following are multiple choice questions (with answers) about professional\ + \ medicine.\n\n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_professional_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21a39c51b7d246c3dd49e47ee0f5dd1865059c36 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_professional_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_psychology" +"description": "The following are multiple choice questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_professional_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2687d99a279caac3f322ff178a1ea1ac7ea44f8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_public_relations.yaml @@ -0,0 +1,6 @@ +"dataset_name": "public_relations" +"description": "The following are multiple choice questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_public_relations" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c36a5522d3c0d6f165dbd5eaac9f5208822fb9d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_security_studies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "security_studies" +"description": "The following are multiple choice questions (with answers) about security\ + \ studies.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_security_studies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ce0809907575855a8680ec1db533688ad42de46 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_sociology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "sociology" +"description": "The following are multiple choice questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_sociology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56ed5e16281b6aca3720868538c93d2877d438b6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_us_foreign_policy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "us_foreign_policy" +"description": "The following are multiple choice questions (with answers) about us\ + \ foreign policy.\n\n" +"tag": "mmlu_flan_n_shot_generative_social_sciences" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_us_foreign_policy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..257dcfbf8a18c96d836d6db1214e8ff69ec63278 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_virology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "virology" +"description": "The following are multiple choice questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_flan_n_shot_generative_other" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_virology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39b64d03d3983f5c692a1a762c8457175dbf5408 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/mmlu_world_religions.yaml @@ -0,0 +1,6 @@ +"dataset_name": "world_religions" +"description": "The following are multiple choice questions (with answers) about world\ + \ religions.\n\n" +"tag": "mmlu_flan_n_shot_generative_humanities" +"include": "_mmlu_flan_generative_template_yaml" +"task": "mmlu_flan_n_shot_generative_world_religions" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72246935de8cf0cf8b256fd1e6c87dfbbb90a2ad --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/generative/utils.py @@ -0,0 +1,112 @@ +import re +import sys +import unicodedata + +from lm_eval.filters.extraction import RegexFilter + + +class MultiChoiceRegexFilter(RegexFilter): + """ """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, + ) -> None: + """ + regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure + - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. + - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. + group_select: Selects the (group_select)th match from the findall result. + ignore_case: Ignores the case during step 1 matching + ignore_punctuation: Remove the punctuation during step 1 matching + regexes_to_ignore: Remove these regexes during step 1 matching + """ + super().__init__(regex_pattern, group_select, fallback) + self.ignore_case = ignore_case + self.ignore_punctuation = ignore_punctuation + self.regexes_to_ignore = regexes_to_ignore + + def apply(self, resps, docs): + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + + def find_match(regex, resp, convert_dict={}): + match = regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m][0] + match = match.strip() + if match and match in convert_dict: + match = convert_dict[match] + return match + + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + + def filter_ignores(st): + if self.regexes_to_ignore is not None: + for s in self.regexes_to_ignore: + st = re.sub(s, "", st) + + if self.ignore_case: + st = st.lower() + + if self.ignore_punctuation: + # https://stackoverflow.com/a/266162 + st = st.translate(punct_tbl) + return st + + filtered_resps = [] + + for r, doc in zip(resps, docs): + fallback_regexes = [] + choice_to_alpha = {} + next_alpha = "A" + + without_paren_fallback_regexes = [] + without_paren_to_target = {} + + choices = doc["choices"] + for c in choices: + m = filter_ignores(c.strip()) + fallback_regexes.append(f"{re.escape(m)}") + choice_to_alpha[m] = f"({next_alpha})" + + without_paren_fallback_regexes.append(next_alpha) + without_paren_to_target[next_alpha] = f"({next_alpha})" + + next_alpha = chr(ord(next_alpha) + 1) + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) + + filtered = [] + for resp in r: + match = find_match(self.regex, resp) + if not match: + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) + if not match: + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) + if not match: + match = self.fallback + filtered.append(match) + filtered_resps.append(filtered) + + return filtered_resps diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2cfa0fb9c30451fa79f6b8b038a01692c830f1a7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/_mmlu.yaml @@ -0,0 +1,32 @@ +group: mmlu_flan_n_shot_loglikelihood +group_alias: mmlu (flan style, loglikelihood) +task: + - group: stem + task: + - mmlu_flan_n_shot_loglikelihood_stem + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: other + task: + - mmlu_flan_n_shot_loglikelihood_other + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: social sciences + task: + - mmlu_flan_n_shot_loglikelihood_social_sciences + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: humanities + task: + - mmlu_flan_n_shot_loglikelihood_humanities + aggregate_metric_list: + - metric: acc + weight_by_size: True +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/_mmlu_flan_loglikelihood_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/_mmlu_flan_loglikelihood_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..4605a4a15f2e84c4572388192fc1e51d717f70b1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/_mmlu_flan_loglikelihood_template_yaml @@ -0,0 +1,17 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +test_split: test +fewshot_split: dev +fewshot_config: + sampler: first_n +output_type: multiple_choice +doc_to_text: "Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA:" +doc_to_choice: ["(A)", "(B)", "(C)", "(D)"] +doc_to_target: answer +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true +metadata: + version: 2.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5dfa65ded384d6e1299b8e5564f5a655f2ced79 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_abstract_algebra.yaml @@ -0,0 +1,6 @@ +"dataset_name": "abstract_algebra" +"description": "The following are multiple choice questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_abstract_algebra" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e837e5d8fd3e1577af4d23d2120d1b55029f052f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_anatomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "anatomy" +"description": "The following are multiple choice questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_anatomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43b9bc7ed89429c2d08cc74cc4472ebea28f67a2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_astronomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "astronomy" +"description": "The following are multiple choice questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_astronomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2438e6678be07c008922d83ea5016efab56ebc78 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_business_ethics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "business_ethics" +"description": "The following are multiple choice questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_business_ethics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82d66adda5d600a94d5f6e36544dd63d2de3fece --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_clinical_knowledge.yaml @@ -0,0 +1,6 @@ +"dataset_name": "clinical_knowledge" +"description": "The following are multiple choice questions (with answers) about clinical\ + \ knowledge.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_clinical_knowledge" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15e6e75d3491dfd034df789a3481fb3a39dcaa02 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_biology" +"description": "The following are multiple choice questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_college_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b8c1bd3a8de310698082f738d287743d3731c23 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_chemistry" +"description": "The following are multiple choice questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_college_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1178c7b072f82bebdd4281a371d6105514a686e8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_computer_science" +"description": "The following are multiple choice questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_college_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9776889b514c04c6c93aeedfd0ced7c620d11493 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_mathematics" +"description": "The following are multiple choice questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_college_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8fdad90bd103ff616b4b14c2a3e9024208e149a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_medicine" +"description": "The following are multiple choice questions (with answers) about college\ + \ medicine.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_college_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..77a89689127b4ca129b9434653198b051324fc0a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_college_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_physics" +"description": "The following are multiple choice questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_college_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e787e51745218e2465b739ee82b51c456bd228ab --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_computer_security.yaml @@ -0,0 +1,6 @@ +"dataset_name": "computer_security" +"description": "The following are multiple choice questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_computer_security" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..859e88e48a5cea7114b85c31c594f832520bacb0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_conceptual_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "conceptual_physics" +"description": "The following are multiple choice questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_conceptual_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0455a515eab5e3102a659d917758b942c00b952d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_econometrics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "econometrics" +"description": "The following are multiple choice questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_econometrics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b63e06172ec302a916f3be4b0a2ea0f1efa86674 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_electrical_engineering.yaml @@ -0,0 +1,6 @@ +"dataset_name": "electrical_engineering" +"description": "The following are multiple choice questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_electrical_engineering" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79771d21543868dd73bf6ff84201ef07d79c89a2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_elementary_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are multiple choice questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_elementary_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e46d8e21c62ff03a6f47bbbc7a6d085840049a4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_formal_logic.yaml @@ -0,0 +1,6 @@ +"dataset_name": "formal_logic" +"description": "The following are multiple choice questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_formal_logic" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e7aff59325d7dab9a02c4eda3a886d062fe3b4a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_global_facts.yaml @@ -0,0 +1,6 @@ +"dataset_name": "global_facts" +"description": "The following are multiple choice questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_global_facts" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dfe33de2be1d2f821c92fc46111150e1ac366b7e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_biology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..661ea0ca2f72242eb4daf520f6683a9de3a7c32c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are multiple choice questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b271a661f943fdd6d364833c9f994c19ee10cd22 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are multiple choice questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1a329ebb24804c92690b5210cb27f6ec47be93d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_european_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_european_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_european_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe681101f6704e7f058e27350b37838ba63fcd07 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_geography.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_geography" +"description": "The following are multiple choice questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_geography" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8a8f279fc5bfaa8b610f3ff5dcd1c2be0c88e07 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45664135facb151e9b6f91347bbc135297880acb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49903260ceff13c03070606e04beb45d99d660f7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..394c1d77e553a24820ba5db934bfa8fd95a8a269 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_microeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f32ef2fcc4bf03e34b43c5a3d1135431742db71 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_physics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school physics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a9aac0736a9610469c70b925b70b3f384ca9777 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_psychology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e7e02afb94aebb1676c5c395c51e37d4f149a39 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_statistics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_statistics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school statistics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_statistics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bc84ea9dd78e87166f1e7b67c248d242cb98d83 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_us_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_us_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_us_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f25cf646bebbca23ed23ea421473e6c2461dda8a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_high_school_world_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_world_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_high_school_world_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c258f919041775e1d2bf1226264a10b1133802db --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_human_aging.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_aging" +"description": "The following are multiple choice questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_human_aging" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e192a78b48bd37e4dc37efc5783b527f84c3e55 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_human_sexuality.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_sexuality" +"description": "The following are multiple choice questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_human_sexuality" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..662bf6eb35157889356a6be7ded31d5f6f2a39ac --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_international_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "international_law" +"description": "The following are multiple choice questions (with answers) about international\ + \ law.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_international_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82036dc1da79c464f21f90b46e4681b061fe5ea1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_jurisprudence.yaml @@ -0,0 +1,6 @@ +"dataset_name": "jurisprudence" +"description": "The following are multiple choice questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_jurisprudence" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..346e4b669771f23d7a3a805b329e96e711cd367e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_logical_fallacies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "logical_fallacies" +"description": "The following are multiple choice questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_logical_fallacies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d7c280155ae7302b0bed56715c5ea92191e3faf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_machine_learning.yaml @@ -0,0 +1,6 @@ +"dataset_name": "machine_learning" +"description": "The following are multiple choice questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_stem" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_machine_learning" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a732a778fb85eac5467fe2744e51340bce0c302 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_management.yaml @@ -0,0 +1,6 @@ +"dataset_name": "management" +"description": "The following are multiple choice questions (with answers) about management.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_management" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56760226dba043ba37a110cf7065bbd52c3e9c93 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_marketing.yaml @@ -0,0 +1,6 @@ +"dataset_name": "marketing" +"description": "The following are multiple choice questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_marketing" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6635c9613155a7c23bf67329b4be950e57fe2d30 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_medical_genetics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "medical_genetics" +"description": "The following are multiple choice questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_medical_genetics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce8dff42a80057d6557f81e5aead49b4e93e4ef3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_miscellaneous.yaml @@ -0,0 +1,6 @@ +"dataset_name": "miscellaneous" +"description": "The following are multiple choice questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_miscellaneous" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62460e82f3386022679443efe3c989c2ffb59abf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_moral_disputes.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_disputes" +"description": "The following are multiple choice questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_moral_disputes" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..408c69f11630d4c237079e327b1b4c9fe3971dc9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_moral_scenarios.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_scenarios" +"description": "The following are multiple choice questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_moral_scenarios" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5494f9dc462494e198dfc7ad86d63a186637bf5c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_nutrition.yaml @@ -0,0 +1,6 @@ +"dataset_name": "nutrition" +"description": "The following are multiple choice questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_nutrition" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fd1f01a1707126eaf93e6a668d681408c8c7fe6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_philosophy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "philosophy" +"description": "The following are multiple choice questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_philosophy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1eb08bfbeb44dc8279ab6796e673b3b271517548 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_prehistory.yaml @@ -0,0 +1,6 @@ +"dataset_name": "prehistory" +"description": "The following are multiple choice questions (with answers) about prehistory.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_prehistory" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a23a990afe0abf5b354a15dfec3b5bbd2775fc9 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_accounting.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_accounting" +"description": "The following are multiple choice questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_professional_accounting" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a80f2baacbe17445f4a1ea564c7a72174b1c445 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_law" +"description": "The following are multiple choice questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_professional_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..da9e30e118445c2a796eb4145f5c308e9e33215f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_medicine.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_medicine" +"description": "The following are multiple choice questions (with answers) about professional\ + \ medicine.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_professional_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce7043a07273fe781cb56733ab02b0b1cb4bf059 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_professional_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_psychology" +"description": "The following are multiple choice questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_professional_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..debace7ca0d0c1dbdfad6ad1621dcc5d9a1469eb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_public_relations.yaml @@ -0,0 +1,6 @@ +"dataset_name": "public_relations" +"description": "The following are multiple choice questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_public_relations" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb1f585ce890b5a4fcccc72c3066691509330c49 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_security_studies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "security_studies" +"description": "The following are multiple choice questions (with answers) about security\ + \ studies.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_security_studies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0580f7ae31687590598a77fe950d282020d9be16 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_sociology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "sociology" +"description": "The following are multiple choice questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_sociology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ff2d9ea791abf7bad56784b804bcadd5c82c077 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_us_foreign_policy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "us_foreign_policy" +"description": "The following are multiple choice questions (with answers) about us\ + \ foreign policy.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_social_sciences" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_us_foreign_policy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3edfd9528eed5d4199ebb0ba06a328ed8c50dd8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_virology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "virology" +"description": "The following are multiple choice questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_flan_n_shot_loglikelihood_other" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_virology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..765e70c8fc22dbd75ba495a6490ec788d4e44b7e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/flan_n_shot/loglikelihood/mmlu_world_religions.yaml @@ -0,0 +1,6 @@ +"dataset_name": "world_religions" +"description": "The following are multiple choice questions (with answers) about world\ + \ religions.\n\n" +"tag": "mmlu_flan_n_shot_loglikelihood_humanities" +"include": "_mmlu_flan_loglikelihood_template_yaml" +"task": "mmlu_flan_n_shot_loglikelihood_world_religions" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/_default_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/_default_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..71371402e1a1d74c961492c26b9f08d08d9acd28 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/_default_template_yaml @@ -0,0 +1,29 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +test_split: test +fewshot_split: dev +fewshot_config: + sampler: first_n +output_type: generate_until +doc_to_text: "{{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nAnswer:" +doc_to_target: "{{['A', 'B', 'C', 'D'][answer]}}" +generation_kwargs: + until: + - "" + - "\n" +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_punctuation: true + ignore_case: true +filter_list: + - name: get_response + filter: + # Filter everything after the first letter + - function: "regex" + regex_pattern: "^([A-Z]).*" + - function: take_first +metadata: + version: 3.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/_mmlu.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..550caa37606f975110f5e4f425d27e594014c116 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/_mmlu.yaml @@ -0,0 +1,38 @@ +group: mmlu_generative +group_alias: mmlu (generative) +task: + - group: stem + task: + - mmlu_stem_generative + aggregate_metric_list: + - metric: exact_match + weight_by_size: true + filter_list: get_response + - group: other + task: + - mmlu_other_generative + aggregate_metric_list: + - metric: exact_match + weight_by_size: true + filter_list: get_response + - group: social sciences + task: + - mmlu_social_sciences_generative + aggregate_metric_list: + - metric: exact_match + weight_by_size: true + filter_list: get_response + - group: humanities + task: + - mmlu_humanities_generative + aggregate_metric_list: + - metric: exact_match + weight_by_size: true + filter_list: get_response +aggregate_metric_list: + - aggregation: mean + metric: exact_match + weight_by_size: true + filter_list: get_response +metadata: + version: 3 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17bfcafb79b113cffe93f6e90c68562b7eae7c95 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_abstract_algebra.yaml @@ -0,0 +1,7 @@ +"dataset_name": "abstract_algebra" +"description": "The following are multiple choice questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_abstract_algebra_generative" +"task_alias": "abstract_algebra" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_anatomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72afc359a495af12d3dcb2b062c6442d92d45c88 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_anatomy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "anatomy" +"description": "The following are multiple choice questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_anatomy_generative" +"task_alias": "anatomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_astronomy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0b41447e74a2b95732b102bfe5ed642d3d208d2b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_astronomy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "astronomy" +"description": "The following are multiple choice questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_astronomy_generative" +"task_alias": "astronomy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7c15d443691af36dcdc761eb41b8673f3782d0b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_business_ethics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "business_ethics" +"description": "The following are multiple choice questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_business_ethics_generative" +"task_alias": "business_ethics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_clinical_knowledge.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_clinical_knowledge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..24cd0b72d3f68fb00da90397979816b85ea1c76c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_clinical_knowledge.yaml @@ -0,0 +1,7 @@ +"dataset_name": "clinical_knowledge" +"description": "The following are multiple choice questions (with answers) about clinical\ + \ knowledge.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_clinical_knowledge_generative" +"task_alias": "clinical_knowledge" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ff9cc284007337e30369dd4864b2b723e8e6768 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_biology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_biology" +"description": "The following are multiple choice questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_college_biology_generative" +"task_alias": "college_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12d9ce3eab1332fa202cf6f99a52785865aed1a7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_chemistry.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_chemistry" +"description": "The following are multiple choice questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_college_chemistry_generative" +"task_alias": "college_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..73d91c52acd76bf99ce1869296257d25143ad149 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_computer_science.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_computer_science" +"description": "The following are multiple choice questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_college_computer_science_generative" +"task_alias": "college_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15ae9dded855610af45a15bab8aa56596bfaddd4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_mathematics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_mathematics" +"description": "The following are multiple choice questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_college_mathematics_generative" +"task_alias": "college_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0461ab7ae7dab9df6b10591fd14791a2cc3eff0f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_medicine.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_medicine" +"description": "The following are multiple choice questions (with answers) about college\ + \ medicine.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_college_medicine_generative" +"task_alias": "college_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d997d8974c99a549a2216a9bd9237f05a619e21 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_college_physics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "college_physics" +"description": "The following are multiple choice questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_college_physics_generative" +"task_alias": "college_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_computer_security.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee64d20100e25fc4bcf7f446b1e98acf042c4ab8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_computer_security.yaml @@ -0,0 +1,7 @@ +"dataset_name": "computer_security" +"description": "The following are multiple choice questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_computer_security_generative" +"task_alias": "computer_security" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..75764a2cbf542ba09a99ae252c76a103bf534a9f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_conceptual_physics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "conceptual_physics" +"description": "The following are multiple choice questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_conceptual_physics_generative" +"task_alias": "conceptual_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_econometrics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43fec80ad3f505bedb810df609a8c6e8d2c2c0ed --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_econometrics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "econometrics" +"description": "The following are multiple choice questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_econometrics_generative" +"task_alias": "econometrics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..130ec2b2aa2210322c1e2f86cdf6be31dd72bffc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_electrical_engineering.yaml @@ -0,0 +1,7 @@ +"dataset_name": "electrical_engineering" +"description": "The following are multiple choice questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_electrical_engineering_generative" +"task_alias": "electrical_engineering" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4afd087dc47f27653b54ff48a27a187bc9af07bc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_elementary_mathematics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are multiple choice questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_elementary_mathematics_generative" +"task_alias": "elementary_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72c28c0b188b8b8fd69ba9ed79595f0d173f71cf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_formal_logic.yaml @@ -0,0 +1,7 @@ +"dataset_name": "formal_logic" +"description": "The following are multiple choice questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_formal_logic_generative" +"task_alias": "formal_logic" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_global_facts.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b788025ad5ddf0d859fc12a0d0f139c0975b16ba --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_global_facts.yaml @@ -0,0 +1,7 @@ +"dataset_name": "global_facts" +"description": "The following are multiple choice questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_global_facts_generative" +"task_alias": "global_facts" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3677842dcfc091bb28525889479a48096cbb854d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_biology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_biology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_biology_generative" +"task_alias": "high_school_biology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2df93cab2a999a7d6d8e78d3ac9c3ce9aeddcf12 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_chemistry.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are multiple choice questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_chemistry_generative" +"task_alias": "high_school_chemistry" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec5dc7f89abd7ddc57438c71e0502fce1ac47279 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_computer_science.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are multiple choice questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_computer_science_generative" +"task_alias": "high_school_computer_science" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9732754bbd7352957dbe299494083e17b960c1bc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_european_history.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_european_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_european_history_generative" +"task_alias": "high_school_european_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..66b1a3c97a64f9ee7db414ab13d3146efba5612d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_geography.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_geography" +"description": "The following are multiple choice questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_geography_generative" +"task_alias": "high_school_geography" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46861fdc1149b72d4ac3f347c0e09f679f6c6e54 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_government_and_politics_generative" +"task_alias": "high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ada415922b2b777f153cf387f9095cce9c75304b --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_macroeconomics_generative" +"task_alias": "high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b22a5888e61be187f5bbbca1e38171eecd6252d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_mathematics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_mathematics_generative" +"task_alias": "high_school_mathematics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c59ff16270084981614d6f01065851c005039413 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_microeconomics_generative" +"task_alias": "high_school_microeconomics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21d846afb9c8c6b372d59ee462561bb8f67ae83e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_physics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_physics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school physics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_physics_generative" +"task_alias": "high_school_physics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cd1321a5f17efca463edbc6711c197fb18c3a81d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_psychology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_psychology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_psychology_generative" +"task_alias": "high_school_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_statistics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_statistics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1442fb8df4168606151af5cc1dfd769bb2e70e3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_statistics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_statistics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school statistics.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_statistics_generative" +"task_alias": "high_school_statistics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4552a560f38e3ed5db503fa677548a11766873c2 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_us_history.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_us_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_us_history_generative" +"task_alias": "high_school_us_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d510f22ff39219829e6a9030cb39dc2c43062ca4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_high_school_world_history.yaml @@ -0,0 +1,7 @@ +"dataset_name": "high_school_world_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_high_school_world_history_generative" +"task_alias": "high_school_world_history" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_human_aging.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56352f4a8c86966853cdbafd68453d1ee85dbabb --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_human_aging.yaml @@ -0,0 +1,7 @@ +"dataset_name": "human_aging" +"description": "The following are multiple choice questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_human_aging_generative" +"task_alias": "human_aging" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a23559cfb36a380131573f46b30bbdb5f4656b42 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_human_sexuality.yaml @@ -0,0 +1,7 @@ +"dataset_name": "human_sexuality" +"description": "The following are multiple choice questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_human_sexuality_generative" +"task_alias": "human_sexuality" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_international_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_international_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..878df6f3cacb299a51afacca461204fdc4e3a782 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_international_law.yaml @@ -0,0 +1,7 @@ +"dataset_name": "international_law" +"description": "The following are multiple choice questions (with answers) about international\ + \ law.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_international_law_generative" +"task_alias": "international_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5782d81551072a0ff03d79c930f02edb64488f3 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_jurisprudence.yaml @@ -0,0 +1,7 @@ +"dataset_name": "jurisprudence" +"description": "The following are multiple choice questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_jurisprudence_generative" +"task_alias": "jurisprudence" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..43e8e0168b9f4638cc80b76ff1a4edc8893212b4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_logical_fallacies.yaml @@ -0,0 +1,7 @@ +"dataset_name": "logical_fallacies" +"description": "The following are multiple choice questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_logical_fallacies_generative" +"task_alias": "logical_fallacies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d39a4b53164ce8bb641c99fa50f24ace308d3f4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_machine_learning.yaml @@ -0,0 +1,7 @@ +"dataset_name": "machine_learning" +"description": "The following are multiple choice questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_stem_generative" +"include": "_default_template_yaml" +"task": "mmlu_machine_learning_generative" +"task_alias": "machine_learning" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_management.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d51ea0d0aa41fb4b2579162111aa8ebd8ce8f6d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_management.yaml @@ -0,0 +1,7 @@ +"dataset_name": "management" +"description": "The following are multiple choice questions (with answers) about management.\n\ + \n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_management_generative" +"task_alias": "management" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_marketing.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..744385a2ea524d6f651851856e15aaf190eb847e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_marketing.yaml @@ -0,0 +1,7 @@ +"dataset_name": "marketing" +"description": "The following are multiple choice questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_marketing_generative" +"task_alias": "marketing" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fea57959818525acdada5bf8a327b0ce96fefb0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_medical_genetics.yaml @@ -0,0 +1,7 @@ +"dataset_name": "medical_genetics" +"description": "The following are multiple choice questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_medical_genetics_generative" +"task_alias": "medical_genetics" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7e0fabc2536d4894526b680deba9a382ff9c3ff --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_miscellaneous.yaml @@ -0,0 +1,7 @@ +"dataset_name": "miscellaneous" +"description": "The following are multiple choice questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_miscellaneous_generative" +"task_alias": "miscellaneous" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..61d2feee6a9cf4ed4d71b7c2f9aa68f5219c270a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_moral_disputes.yaml @@ -0,0 +1,7 @@ +"dataset_name": "moral_disputes" +"description": "The following are multiple choice questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_moral_disputes_generative" +"task_alias": "moral_disputes" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2aeb93f967f0811d3a2f1d886aedfb334a96714e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_moral_scenarios.yaml @@ -0,0 +1,7 @@ +"dataset_name": "moral_scenarios" +"description": "The following are multiple choice questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_moral_scenarios_generative" +"task_alias": "moral_scenarios" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_nutrition.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..638ac8100b6f918ccaa0a3dc13946512d3c97b33 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_nutrition.yaml @@ -0,0 +1,7 @@ +"dataset_name": "nutrition" +"description": "The following are multiple choice questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_nutrition_generative" +"task_alias": "nutrition" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..149894b8484cb1fad9ddad1fc5cb2c07a659aea1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_philosophy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "philosophy" +"description": "The following are multiple choice questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_philosophy_generative" +"task_alias": "philosophy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_prehistory.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_prehistory.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e130e1baacc3f8a8f558b568336896668e84dd4f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_prehistory.yaml @@ -0,0 +1,7 @@ +"dataset_name": "prehistory" +"description": "The following are multiple choice questions (with answers) about prehistory.\n\ + \n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_prehistory_generative" +"task_alias": "prehistory" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a46792ec22d84ee3193996653f536084b9ab7861 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_accounting.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_accounting" +"description": "The following are multiple choice questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_professional_accounting_generative" +"task_alias": "professional_accounting" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f087657e579524b35bf7de4c0f81cb5b697caed4 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_law.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_law" +"description": "The following are multiple choice questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_professional_law_generative" +"task_alias": "professional_law" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_medicine.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_medicine.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc80878980195f58ac5ae26a0a70589a47b325d5 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_medicine.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_medicine" +"description": "The following are multiple choice questions (with answers) about professional\ + \ medicine.\n\n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_professional_medicine_generative" +"task_alias": "professional_medicine" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d0b36ccde61e7edc33464a676d4fe0fcc25f3304 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_professional_psychology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "professional_psychology" +"description": "The following are multiple choice questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_professional_psychology_generative" +"task_alias": "professional_psychology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_public_relations.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37cdccba9b7cebbaa34c5f1e9da01655367477f6 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_public_relations.yaml @@ -0,0 +1,7 @@ +"dataset_name": "public_relations" +"description": "The following are multiple choice questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_public_relations_generative" +"task_alias": "public_relations" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_security_studies.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_security_studies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36c235feefd1548320400e7e8d9f3e03f2d478d0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_security_studies.yaml @@ -0,0 +1,7 @@ +"dataset_name": "security_studies" +"description": "The following are multiple choice questions (with answers) about security\ + \ studies.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_security_studies_generative" +"task_alias": "security_studies" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_sociology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7e2e592e4457118c9458ccb757b823f9adbb193 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_sociology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "sociology" +"description": "The following are multiple choice questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_sociology_generative" +"task_alias": "sociology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_us_foreign_policy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_us_foreign_policy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5fb95366245eae638918270bff4353024195d5f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_us_foreign_policy.yaml @@ -0,0 +1,7 @@ +"dataset_name": "us_foreign_policy" +"description": "The following are multiple choice questions (with answers) about us\ + \ foreign policy.\n\n" +"tag": "mmlu_social_sciences_generative" +"include": "_default_template_yaml" +"task": "mmlu_us_foreign_policy_generative" +"task_alias": "us_foreign_policy" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_virology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9954dc182f1bbd5030b94d2a08b2ddf4a135a6cf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_virology.yaml @@ -0,0 +1,7 @@ +"dataset_name": "virology" +"description": "The following are multiple choice questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_other_generative" +"include": "_default_template_yaml" +"task": "mmlu_virology_generative" +"task_alias": "virology" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_world_religions.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_world_religions.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1db5128b43e615d0fc41f9c7448db3b5ea39942c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu/generative/mmlu_world_religions.yaml @@ -0,0 +1,7 @@ +"dataset_name": "world_religions" +"description": "The following are multiple choice questions (with answers) about world\ + \ religions.\n\n" +"tag": "mmlu_humanities_generative" +"include": "_default_template_yaml" +"task": "mmlu_world_religions_generative" +"task_alias": "world_religions" diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/README.md b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8fddbef266f755ae611c38b6ec2aea05ff6aa033 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/README.md @@ -0,0 +1,64 @@ +# mmlu_pro + +### Paper + +Title: `MMLU-Pro: A More Robust and Challenging Multi-Task Language Understanding Benchmark` + +Abstract: `In the age of large-scale language models, benchmarks like the Massive Multitask Language Understanding (MMLU) have been pivotal in pushing the boundaries of what AI can achieve in language comprehension and reasoning across diverse domains. However, as models continue to improve, their performance on these benchmarks has begun to plateau, making it increasingly difficult to discern differences in model capabilities. This paper introduces MMLU-Pro, an enhanced dataset designed to extend the mostly knowledge-driven MMLU benchmark by integrating more challenging, reasoning-focused questions and expanding the choice set from four to ten options. Additionally, MMLU-Pro eliminates the trivial and noisy questions in MMLU. Our experimental results show that MMLU-Pro not only raises the challenge, causing a significant drop in accuracy by 16% to 33% compared to MMLU but also demonstrates greater stability under varying prompts. With 24 different prompt styles tested, the sensitivity of model scores to prompt variations decreased from 4-5% in MMLU to just 2% in MMLU-Pro. Additionally, we found that models utilizing Chain of Thought (CoT) reasoning achieved better performance on MMLU-Pro compared to direct answering, which is in stark contrast to the findings on the original MMLU, indicating that MMLU-Pro includes more complex reasoning questions. Our assessments confirm that MMLU-Pro is a more discriminative benchmark to better track progress in the field.` + +Homepage: https://huggingface.co/datasets/TIGER-Lab/MMLU-Pro + +### Citation + +```bibtex +@misc{wang2024mmlupro, + title={MMLU-Pro: A More Robust and Challenging Multi-Task Language Understanding Benchmark}, + author={Yubo Wang and Xueguang Ma and Ge Zhang and Yuansheng Ni and Abhranil Chandra and Shiguang Guo and Weiming Ren and Aaran Arulraj and Xuan He and Ziyan Jiang and Tianle Li and Max Ku and Kai Wang and Alex Zhuang and Rongqi Fan and Xiang Yue and Wenhu Chen}, + year={2024}, + eprint={2406.01574}, + archivePrefix={arXiv}, + primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'} +} +``` + +### Groups and Tasks + +#### Groups + +* `mmlu_pro`: 'All 14 subjects of the mmlu_pro dataset, evaluated following the methodology in mmlu's original implementation' + +#### Tasks + +The following tasks evaluate subjects in the mmlu_pro dataset +- `mmlu_pro_biology` +- `mmlu_pro_business` +- `mmlu_pro_chemistry` +- `mmlu_pro_computer_science` +- `mmlu_pro_economics` +- `mmlu_pro_engineering` +- `mmlu_pro_health` +- `mmlu_pro_history` +- `mmlu_pro_law` +- `mmlu_pro_math` +- `mmlu_pro_other` +- `mmlu_pro_philosophy` +- `mmlu_pro_physics` +- `mmlu_pro_psychology` + +### Checklist + +For adding novel benchmarks/datasets to the library: +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + + +If other tasks on this dataset are already supported: +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? + +### Changelog + +* (tasks, group) 2024-09-23 -- (version 1 --> version 2) + * Added one newline to task description(s) as per [reference implementation](https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/47b9891aacb8bd7cda29d5c5ba17b9434dd333bc/evaluate_from_local.py#L93) diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/_default_template_yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/_default_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c4f44b5c96769d70215403b06bccab33f1bbfb7 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/_default_template_yaml @@ -0,0 +1,33 @@ +dataset_path: TIGER-Lab/MMLU-Pro +test_split: test +fewshot_split: validation +fewshot_config: + sampler: first_n + doc_to_text: !function utils.fewshot_to_text + doc_to_target: "" +output_type: generate_until +doc_to_text: !function utils.doc_to_text +doc_to_target: answer +filter_list: + - name: "custom-extract" + filter: + - function: "regex" + regex_pattern: 'answer is \(?([ABCDEFGHIJ])\)?' + # regex_pattern: r".*[aA]nswer:\s*([A-J])", + - function: "take_first" +generation_kwargs: + until: + - "" + - "Q:" + - "<|im_end|>" + do_sample: false + temperature: 0.0 +num_fewshot: 5 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 1.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/_mmlu_pro.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/_mmlu_pro.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc3204127604d6eac759299f77d63ce9ef49d24e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/_mmlu_pro.yaml @@ -0,0 +1,23 @@ +group: mmlu_pro +task: + - mmlu_pro_biology + - mmlu_pro_business + - mmlu_pro_chemistry + - mmlu_pro_computer_science + - mmlu_pro_economics + - mmlu_pro_engineering + - mmlu_pro_health + - mmlu_pro_history + - mmlu_pro_law + - mmlu_pro_math + - mmlu_pro_other + - mmlu_pro_philosophy + - mmlu_pro_physics + - mmlu_pro_psychology +aggregate_metric_list: + - aggregation: mean + metric: exact_match + weight_by_size: true + filter_list: custom-extract +metadata: + version: 2.0 diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_biology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80aee85108ce8ac9452c0d0d07341e05b11b0b7a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_biology.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about biology. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_biology" +task_alias: "biology" +process_docs: !function utils.process_biology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_business.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_business.yaml new file mode 100644 index 0000000000000000000000000000000000000000..daf871f6bb5abd614c2058c0552389d41cbced50 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_business.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about business. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_business" +task_alias: "business" +process_docs: !function utils.process_business diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_chemistry.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5baf354ec202e66647bacdc9e9008617cd2d4244 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_chemistry.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about chemistry. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_chemistry" +task_alias: "chemistry" +process_docs: !function utils.process_chemistry diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_computer_science.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7de347373e7b92d8eb9ba33bbf8ec4fb2a3dbc49 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_computer_science.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about computer science. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_computer_science" +task_alias: "computer_science" +process_docs: !function utils.process_computer_science diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_economics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_economics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..274612783fb749aad10ec67698c3abf5dce13ebf --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_economics.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about economics. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_economics" +task_alias: "economics" +process_docs: !function utils.process_economics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_engineering.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dcf02f5029823ab83c1cebfbbf9819ac5ec4f1d0 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_engineering.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about engineering. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_engineering" +task_alias: "engineering" +process_docs: !function utils.process_engineering diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_health.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_health.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d161d1d81a29a06a9425bafc9e86211a2669b9be --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_health.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about health. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_health" +task_alias: "health" +process_docs: !function utils.process_health diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_history.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d28efd3e77ef31dc2df96847c90e7c76920007bc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_history.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about history. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_history" +task_alias: "history" +process_docs: !function utils.process_history diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_law.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ba99f16dbb8df07ab9272d8b47d23839a919bc9c --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_law.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about law. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_law" +task_alias: "law" +process_docs: !function utils.process_law diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_math.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_math.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0526f11f975f556f94c21acbc7074c0b4b76dd42 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_math.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about math. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_math" +task_alias: "math" +process_docs: !function utils.process_math diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_other.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_other.yaml new file mode 100644 index 0000000000000000000000000000000000000000..beb2ec9da7011dd1437828fc442ad21733cdf1a8 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_other.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about other. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_other" +task_alias: "other" +process_docs: !function utils.process_other diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_philosophy.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..99e5d65b4c0af9a2a35589ab104cb07be100cabc --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_philosophy.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about philosophy. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_philosophy" +task_alias: "philosophy" +process_docs: !function utils.process_philosophy diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_physics.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e7fa740bd58ebd55382bd8464abd9abaee3d96e --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_physics.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about physics. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_physics" +task_alias: "physics" +process_docs: !function utils.process_physics diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_psychology.yaml b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b28fb72d329d4d9755e45201275604d257851754 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/mmlu_pro_psychology.yaml @@ -0,0 +1,5 @@ +description: "The following are multiple choice questions (with answers) about psychology. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice.\n" +include: "_default_template_yaml" +task: "mmlu_pro_psychology" +task_alias: "psychology" +process_docs: !function utils.process_psychology diff --git a/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/utils.py b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..03117be5f165fd7edf40404bf9934b3753039f1d --- /dev/null +++ b/Prism/Dream/Dream_Baseline/eval_instruct/lm_eval/tasks/mmlu_pro/utils.py @@ -0,0 +1,63 @@ +from functools import partial + + +choices = [ + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", +] + + +def format_cot_example(example, including_answer=True): + prompt = "Question:\n" + question = example["question"] + options = example["options"] + prompt += question + "\n" + prompt += "Options:\n" + for i, opt in enumerate(options): + prompt += "{}. {}\n".format(choices[i], opt) + if including_answer: + cot_content = example["cot_content"].replace( + "A: Let's think step by step.", "Answer: Let's think step by step." + ) + prompt += cot_content + "\n\n" + else: + prompt += "Answer: Let's think step by step." + return prompt + + +doc_to_text = partial(format_cot_example, including_answer=False) +fewshot_to_text = partial(format_cot_example, including_answer=True) + + +def process_docs(dataset, subject): + return dataset.filter(lambda x: x["category"] == subject) + + +process_biology = partial(process_docs, subject="biology") +process_business = partial(process_docs, subject="business") +process_chemistry = partial(process_docs, subject="chemistry") +process_computer_science = partial(process_docs, subject="computer science") +process_economics = partial(process_docs, subject="economics") +process_engineering = partial(process_docs, subject="engineering") +process_health = partial(process_docs, subject="health") +process_history = partial(process_docs, subject="history") +process_law = partial(process_docs, subject="law") +process_math = partial(process_docs, subject="math") +process_other = partial(process_docs, subject="other") +process_philosophy = partial(process_docs, subject="philosophy") +process_physics = partial(process_docs, subject="physics") +process_psychology = partial(process_docs, subject="psychology") diff --git a/Prism/Dream/Dream_Baseline/src/diffllm/__init__.py b/Prism/Dream/Dream_Baseline/src/diffllm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/Dream/Dream_Baseline/src/diffllm/gen_utils.py b/Prism/Dream/Dream_Baseline/src/diffllm/gen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e572f8cdabb75ced4b104df2e0dec071b234b77 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/src/diffllm/gen_utils.py @@ -0,0 +1,110 @@ +import torch +import torch.nn.functional as F + + +def q_sample( + input_ids, + maskable_mask, + mask_token_id, + min=0.0, + max=1.0, + eos_token_id=None, + t=None, + t_mask=None, +): + x_0 = input_ids + + if t_mask is None: + if t is None: + t = torch.rand((x_0.shape[0],), dtype=torch.float, device=input_ids.device) + t = min + (max - min) * t + u = torch.rand_like(x_0, dtype=torch.float) # t/T prob to mask + t_mask = (u < t[:, None]) & maskable_mask + + x_t = x_0.masked_fill(t_mask, mask_token_id) + + if eos_token_id is not None: + # get the last non-eos token index + last_non_eos_token_idx = ((input_ids != eos_token_id) | (~maskable_mask)).sum( + dim=-1 + ) - 1 + seq_len = x_0.shape[1] + + for i in range(x_0.shape[0]): + if last_non_eos_token_idx[i] < seq_len - 1: # with eos tokens + t_mask_at_eos = t_mask[ + i, last_non_eos_token_idx[i] + 1 + ] # use arbitrary eos token + # t_mask[i, last_non_eos_token_idx[i] + 2:] = False # only learn the first eos token + if t_mask_at_eos: + x_t[i, last_non_eos_token_idx[i] + 1 :] = mask_token_id + t_mask[i, last_non_eos_token_idx[i] + 1 :] = True + else: + x_t[i, last_non_eos_token_idx[i] + 1 :] = eos_token_id + t_mask[i, last_non_eos_token_idx[i] + 1 :] = False + + return x_t, t, t_mask # True means it's "MASK" token and should have loss + + +def top_p_logits(logits, top_p=None): + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) + mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min) + return logits + + +def top_k_logits(logits, top_k=None): + top_k = min(top_k, logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min) + return logits + + +def sample_tokens( + logits, + temperature=0.0, + top_p=None, + top_k=None, + margin_confidence=False, + neg_entropy=False, +): + + if temperature > 0: + logits = logits / temperature + if top_p is not None and top_p < 1: + logits = top_p_logits(logits, top_p) + if top_k is not None: + logits = top_k_logits(logits, top_k) + probs = torch.softmax(logits, dim=-1) + + if temperature > 0: + try: + x0 = torch.multinomial(probs, num_samples=1).squeeze(-1) + confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1) + except: + confidence, x0 = probs.max(dim=-1) + else: + confidence, x0 = probs.max(dim=-1) + + if margin_confidence: + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # Extract top1 and top2 probabilities + top1_probs = sorted_probs[:, 0] + top2_probs = sorted_probs[:, 1] + # Calculate confidence as top1 - top2 + confidence = top1_probs - top2_probs + + if neg_entropy: + epsilon = 1e-10 + log_probs = torch.log(probs + epsilon) + confidence = torch.sum(probs * log_probs, dim=-1) + + return confidence, x0 diff --git a/Prism/Dream/Dream_Baseline/src/trainer/__init__.py b/Prism/Dream/Dream_Baseline/src/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/Dream/Dream_Baseline/src/trainer/config/sft_trainer.yaml b/Prism/Dream/Dream_Baseline/src/trainer/config/sft_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72b1d49d5d38a1afa5fc6eb82a4cb85e1aab737a --- /dev/null +++ b/Prism/Dream/Dream_Baseline/src/trainer/config/sft_trainer.yaml @@ -0,0 +1,57 @@ +data: + perbatch_cutoff: False + perbatch_cutoff_type: random + collate_fn_type: optimized + train_batch_size: 256 + micro_batch_size_per_gpu: 4 # this is also val batch size + train_files: ~/data/gsm8k/train.parquet + val_files: ~/data/gsm8k/test.parquet + tokenized: False + prompt_key: question + response_key: answer + max_length: 1024 + truncation: error + chat_template: null + resp_cutoff_ratio: 0.0 + pad_token_id: null +model: + partial_pretrain: ~/models/gemma-1.1-7b-it + fsdp_config: + wrap_policy: + min_num_params: 0 + cpu_offload: False + offload_params: False + external_lib: null + enable_gradient_checkpointing: False + trust_remote_code: False + lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) + lora_alpha: 16 # LoRA scaling factor + target_modules: all-linear # Target modules for LoRA adaptation + use_liger: False + attention_dropout: 0.0 +optim: + lr: 1e-5 + betas: [0.9, 0.95] + weight_decay: 0.01 + warmup_steps_ratio: 0.1 + clip_grad: 1.0 +ulysses_sequence_parallel_size: 1 +use_remove_padding: False +trainer: + default_local_dir: /tmp/sft_model + default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here + resume_path: null + resume_training: False + project_name: gsm8k-sft + experiment_name: test + total_epochs: 3 + total_training_steps: null + logger: ['console'] + seed: 1 + save_checkpoint_steps: 1000 +diffusion: + token_reweighting: false # use focal loss for token-level reweighting + alpha: 0.25 # for focal loss + gamma: 2 # for focal loss + time_reweighting: original # time-level reweighting strategy + cart_p: 0.1 diff --git a/Prism/Dream/Dream_Baseline/src/trainer/fsdp_sft_trainer.py b/Prism/Dream/Dream_Baseline/src/trainer/fsdp_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a46307dbf8bfb2a5082aa34c89ba76cd6b095ea1 --- /dev/null +++ b/Prism/Dream/Dream_Baseline/src/trainer/fsdp_sft_trainer.py @@ -0,0 +1,1159 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A lightweight one-file FSDP SFT Trainer +TODO(zhangchi.usc1992) +- Add calculation of mfu +- Add validation +""" + +import os + +os.environ["NCCL_DEBUG"] = "WARN" +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +import logging +import math +import random +import re +from contextlib import nullcontext +from typing import List + +import hydra +import numpy as np +import torch +import torch.distributed +import verl.utils.hdfs_io as hdfs_io +from peft import LoraConfig, TaskType, get_peft_model +from tensordict import TensorDict +from torch import nn, optim +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.fsdp import CPUOffload +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.utils.data import DataLoader, DistributedSampler +from tqdm import tqdm +from transformers import AutoConfig, AutoModel, PreTrainedModel +from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.distributed import initialize_global_process_group +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.fsdp_utils import ( + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, +) +from verl.utils.model import compute_position_id_with_mask +from verl.utils.torch_functional import get_cosine_schedule_with_warmup +from verl.utils.tracking import Tracking +from verl.workers.sharding_manager import FSDPUlyssesShardingManager + +from src.diffllm.gen_utils import q_sample +from src.trainer.sft_dataset import SFTDataset, TokenizedSFTDataset + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) + + +def extract_step(path): + match = re.search(r"global_step_(\d+)", path) + if match: + return int(match.group(1)) + return None + + +def convert_to_regular_types(obj): + """Convert Hydra configs and other special types to regular Python types.""" + from omegaconf import DictConfig, ListConfig + + if isinstance(obj, (ListConfig, DictConfig)): + return ( + {k: convert_to_regular_types(v) for k, v in obj.items()} + if isinstance(obj, DictConfig) + else list(obj) + ) + elif isinstance(obj, (list, tuple)): + return [convert_to_regular_types(x) for x in obj] + elif isinstance(obj, dict): + return {k: convert_to_regular_types(v) for k, v in obj.items()} + return obj + + +def context_adaptive_reweight(seq_len, distribution="symmetric-geometric", **kwargs): + position_ids_l = np.arange(seq_len).reshape(-1, 1) + position_ids_r = np.arange(seq_len).reshape(1, -1) + distance = position_ids_l - position_ids_r + distance = torch.from_numpy(distance) + + def geometric_distribution(k, cart_p=0.8, **kwargs): + if not 0 < cart_p <= 1: + raise ValueError("p must be between 0 and 1") + + res = (math.log(cart_p) + (k.abs() - 1) * math.log(1 - cart_p)).exp() * 0.5 + res.masked_fill_(k == 0, 0) # ignore distance=0 + return res + + if distribution == "symmetric-geometric": + matrix = geometric_distribution(distance, **kwargs) + else: + raise ValueError(f"Unknown distribution {distribution}") + + return matrix + + +class OptimizedCollateFunction: + """ + Optimized collate function that completes preprocessing during data loading + Reduces GPU computation overhead and improves training efficiency + """ + + def __init__(self, config, tokenizer): + self.config = config + self.tokenizer = tokenizer + self.pad_eos_token_id = ( + config.data.pad_token_id + if config.data.pad_token_id is not None + else tokenizer.pad_token_id + ) + + # Cache configuration items to avoid repeated access + self.enable_perbatch_cutoff = getattr(config.data, "perbatch_cutoff", False) + self.perbatch_cutoff_type = getattr(config.data, "perbatch_cutoff_type", None) + self.resp_cutoff_ratio = getattr(config.data, "resp_cutoff_ratio", 0.0) + + self.random = random.Random(42) + self.np_random = np.random.RandomState(42) + + def __call__(self, batch_samples): + """ + Efficient collate function implementation + Args: + batch_samples: List of samples from dataset + Returns: + Preprocessed batch dict + """ + # Use default collate to merge samples into batch tensors + from torch.utils.data.dataloader import default_collate + + batch = default_collate(batch_samples) + + # Extract tensors + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"].bool() + position_ids = batch["position_ids"] + loss_mask = batch["loss_mask"].bool() + + # 1. Handle perbatch_cutoff related logic + if self.enable_perbatch_cutoff: + input_ids, attention_mask, position_ids, loss_mask = ( + self._apply_perbatch_cutoff( + input_ids, attention_mask, position_ids, loss_mask + ) + ) + + # 2. Handle response truncation + if self.resp_cutoff_ratio > 0.0: + input_ids, attention_mask, position_ids, loss_mask = ( + self._apply_resp_cutoff( + input_ids, attention_mask, position_ids, loss_mask + ) + ) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + } + + def _apply_perbatch_cutoff( + self, input_ids, attention_mask, position_ids, loss_mask + ): + """Apply perbatch cutoff logic""" + if self.perbatch_cutoff_type == "random": + non_pad_lens = (input_ids != self.pad_eos_token_id).sum(-1).cpu() + + # randomly choose a cutoff length from non_pad_lens + cutoff_seq_len = np.random.choice(non_pad_lens) + + # cutoff + input_ids = input_ids[:, :cutoff_seq_len] + attention_mask = attention_mask[:, :cutoff_seq_len] + position_ids = position_ids[:, :cutoff_seq_len] + loss_mask = loss_mask[:, :cutoff_seq_len] + elif self.perbatch_cutoff_type == "random_with_input_pad": + prompt_mask = loss_mask == 0 + response_mask = (loss_mask == 1) & (input_ids != self.pad_eos_token_id) + + prompt_lens = prompt_mask.sum(-1) + response_lens = response_mask.sum(-1) + max_prompt_len = prompt_lens.max() + pad_lens = max_prompt_len - prompt_lens + + # randomly choose a response length from response_lens + kept_response_len = np.random.choice(response_lens.cpu()) + + # rebuild input_ids, attention_mask, loss_mask + new_input_ids = ( + torch.ones( + input_ids.shape[0], + max_prompt_len + kept_response_len, + dtype=input_ids.dtype, + device=input_ids.device, + ) + * self.pad_eos_token_id + ) + new_attention_mask = torch.ones_like( + new_input_ids, dtype=attention_mask.dtype + ) + new_loss_mask = torch.ones_like(new_input_ids, dtype=loss_mask.dtype) + + for i in range(input_ids.shape[0]): + kept_response_len_i = min(kept_response_len, response_lens[i]) + new_input_ids[i, pad_lens[i] : pad_lens[i] + prompt_lens[i]] = ( + input_ids[i][prompt_mask[i]] + ) + new_input_ids[ + i, + pad_lens[i] + + prompt_lens[i] : pad_lens[i] + + prompt_lens[i] + + kept_response_len_i, + ] = input_ids[i][response_mask[i]][:kept_response_len_i] + + new_attention_mask[i, : pad_lens[i]] = 0 + new_loss_mask[i, : pad_lens[i] + prompt_lens[i]] = 0 + + input_ids = new_input_ids + attention_mask = new_attention_mask + position_ids = compute_position_id_with_mask(new_attention_mask) + loss_mask = new_loss_mask + else: + pad_lens = (input_ids == self.pad_eos_token_id).sum(-1) + # cutoff_len = eos_lens.min() - 1 + # assert cutoff_len > 0, input_ids + cutoff_len = pad_lens.min() + assert cutoff_len >= 0 + + # cutoff + seq_len = input_ids.shape[-1] + input_ids = input_ids[:, : seq_len - cutoff_len] + attention_mask = attention_mask[:, : seq_len - cutoff_len] + position_ids = position_ids[:, : seq_len - cutoff_len] + loss_mask = loss_mask[:, : seq_len - cutoff_len] + + return input_ids, attention_mask, position_ids, loss_mask + + def _apply_resp_cutoff(self, input_ids, attention_mask, position_ids, loss_mask): + """Apply response truncation logic""" + import numpy as np + + if self.np_random.rand() < self.resp_cutoff_ratio: + # Calculate response length for each sample (loss_mask True portion) + resp_lens = loss_mask.sum(-1) + min_resp_len = resp_lens.min().item() + + if min_resp_len > 1: + # Randomly select truncation length + cutoff_len = self.np_random.randint(1, min_resp_len) + + # Truncate from the end of sequence + new_seq_len = input_ids.shape[-1] - cutoff_len + input_ids = input_ids[:, :new_seq_len].contiguous() + attention_mask = attention_mask[:, :new_seq_len].contiguous() + position_ids = position_ids[:, :new_seq_len].contiguous() + loss_mask = loss_mask[:, :new_seq_len].contiguous() + + return input_ids, attention_mask, position_ids, loss_mask + + +class StreamingCollateFunction(OptimizedCollateFunction): + """ + Streaming version of collate function + Suitable for large batches or long sequences, reduces memory peaks + """ + + def __call__(self, batch_samples): + """ + Streaming version, processes samples one by one then concatenates + """ + if len(batch_samples) == 0: + return {} + + # Process samples one by one to reduce memory usage + processed_samples = [] + for sample in batch_samples: + processed_sample = self._preprocess_single_sample(sample) + processed_samples.append(processed_sample) + + # Concatenate all processed samples + from torch.utils.data.dataloader import default_collate + + batch = default_collate(processed_samples) + + # Apply batch-level preprocessing (inherited from parent class logic) + return self._apply_batch_preprocessing(batch) + + def _preprocess_single_sample(self, sample): + """Preprocess single sample""" + # Sample-level preprocessing can be implemented here + # e.g., dynamic truncation, special token processing, etc. + return sample + + def _apply_batch_preprocessing(self, batch): + """Apply batch-level preprocessing logic""" + # Extract tensors + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"].bool() + position_ids = batch["position_ids"] + loss_mask = batch["loss_mask"].bool() + + if self.enable_perbatch_cutoff: + input_ids, attention_mask, position_ids, loss_mask = ( + self._apply_perbatch_cutoff( + input_ids, attention_mask, position_ids, loss_mask + ) + ) + + if self.resp_cutoff_ratio > 0.0: + input_ids, attention_mask, position_ids, loss_mask = ( + self._apply_resp_cutoff( + input_ids, attention_mask, position_ids, loss_mask + ) + ) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + } + + +class FSDPSFTTrainer(object): + + def __init__( + self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh + ): + self.config = config + self.device_mesh = device_mesh + self.ulysses_device_mesh = ulysses_device_mesh + self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # Add tracking for current epoch + self.current_epoch = 0 + + # Check if resuming training + self.resume_training = getattr(self.config.trainer, "resume_training", False) + self.resume_checkpoint_path = getattr(self.config.trainer, "resume_path", None) + + # build tokenizer first + if self.resume_training and self.resume_checkpoint_path: + # If resuming from specific checkpoint, use that path for tokenizer + local_model_path = copy_local_path_from_hdfs( + src=self.resume_checkpoint_path, verbose=True + ) + else: + local_model_path = copy_local_path_from_hdfs( + src=self.config.model.partial_pretrain, verbose=True + ) + + from verl.utils import hf_tokenizer + + self.tokenizer = hf_tokenizer( + local_model_path, trust_remote_code=self.config.model.trust_remote_code + ) + if self.config.data.chat_template is not None: + raise ValueError("Apply Chat template from config is not supported yet.") + + # normalize dp size + self._normalize_config_bsz() + + # Set sequence parallel size + self.config.ulysses_sequence_parallel_size = getattr( + self.config, "ulysses_sequence_parallel_size", 1 + ) + self.use_remove_padding = getattr(self.config, "use_remove_padding", False) + if self.device_mesh.get_rank() == 0: + print( + f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}" + ) + print(f"Using remove padding: {self.use_remove_padding}") + + self._build_dataloader() + # build model + self._build_model_optimizer() + + # TODO: add checkpoint manager + if self.device_mesh.get_rank() == 0: + print(self.config) + + def _normalize_config_bsz(self): + dp_size = ( + self.device_mesh.size(0) + if not self.ulysses_device_mesh + else self.ulysses_device_mesh.size(0) + ) + if self.device_mesh.get_rank() == 0: + print(f"Normalize batch size by dp {dp_size}") + + assert ( + self.config.data.train_batch_size % dp_size == 0 + ), f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" + + self.config.data.train_batch_size //= dp_size + + assert ( + self.config.data.train_batch_size + % self.config.data.micro_batch_size_per_gpu + == 0 + ) + + def _build_dataloader(self): + config = self.config + # build dataset + if config.data.tokenized: + self.train_dataset = TokenizedSFTDataset( + parquet_files=config.data.train_files, + ) + self.val_dataset = TokenizedSFTDataset( + parquet_files=config.data.val_files, + ) + else: + self.train_dataset = SFTDataset( + parquet_files=config.data.train_files, + tokenizer=self.tokenizer, + prompt_key=config.data.prompt_key, + response_key=config.data.response_key, + max_length=config.data.max_length, + truncation=config.data.truncation, + pad_token_id=config.data.pad_token_id, + ) + self.val_dataset = SFTDataset( + parquet_files=config.data.val_files, + tokenizer=self.tokenizer, + prompt_key=config.data.prompt_key, + response_key=config.data.response_key, + max_length=config.data.max_length, + truncation=config.data.truncation, + pad_token_id=config.data.pad_token_id, + ) + + # build dataloader + # Use data parallel rank and size instead of global rank and world size + + # If doing SP, we need to use the local rank and size + if self.config.ulysses_sequence_parallel_size > 1: + rank = self.ulysses_device_mesh.get_local_rank("dp") + world_size = self.ulysses_device_mesh.size(0) + if self.ulysses_device_mesh.get_rank() == 0: + print( + f"Using SP rank {rank} and size {world_size} for data distribution" + ) + print( + f"Each SP rank gets different data, but the same data WITHIN the same rank" + ) + else: + rank = self.device_mesh.get_rank() + world_size = self.device_mesh.size() + if self.device_mesh.get_rank() == 0: + print(f"Using FSDP rank {rank} and size {world_size} for data distribution") + + self.train_sampler = DistributedSampler( + self.train_dataset, + shuffle=True, + num_replicas=world_size, + rank=rank, + drop_last=True, + ) + + # Create optimized collate function + collate_fn_type = getattr(config.data, "collate_fn_type", "optimized") + if collate_fn_type == "streaming": + train_collate_fn = StreamingCollateFunction(config, self.tokenizer) + val_collate_fn = StreamingCollateFunction(config, self.tokenizer) + else: + train_collate_fn = OptimizedCollateFunction(config, self.tokenizer) + val_collate_fn = OptimizedCollateFunction(config, self.tokenizer) + + if self.device_mesh.get_rank() == 0: + print(f"Using {collate_fn_type} collate function for data preprocessing") + + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=config.data.train_batch_size, + sampler=self.train_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + prefetch_factor=4, + persistent_workers=True, + collate_fn=train_collate_fn, # Use optimized collate function + ) + + self.val_sampler = DistributedSampler( + self.val_dataset, + shuffle=True, + num_replicas=world_size, + rank=rank, + drop_last=True, + ) + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=config.data.micro_batch_size_per_gpu, + sampler=self.val_sampler, + num_workers=8, + pin_memory=True, + drop_last=True, + prefetch_factor=4, + persistent_workers=True, + collate_fn=val_collate_fn, # Use optimized collate function + ) + + def _build_model_optimizer(self, checkpoint_path=None): + """Build model and optimizer, optionally from a checkpoint.""" + # Determine which path to load from + if checkpoint_path: + local_model_path = checkpoint_path + else: + local_model_path = copy_local_path_from_hdfs( + src=self.config.model.partial_pretrain, verbose=True + ) + + if self.config.model.get("external_lib", None) is not None: + # This is used to import external_lib into the huggingface systems + import importlib + + importlib.import_module(self.config.model.external_lib) + + log_gpu_memory_usage("Before model allocation", logger=logger) + + trust_remote_code = self.config.model.trust_remote_code + # load config first + config = AutoConfig.from_pretrained( + local_model_path, + trust_remote_code=trust_remote_code, + attention_dropout=self.config.model.attention_dropout, + ) + if self.config.ulysses_sequence_parallel_size > 1: + assert ( + self.use_remove_padding + ), "Sequence parallel is only supported when remove_padding is enabled" + from verl.models.registry import check_model_support_rmpad + + check_model_support_rmpad(config.model_type) + + if self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + + apply_monkey_patch(config, verbose=True) + + # This may be very large + init_context = get_init_weight_context_manager( + use_meta_tensor=not config.tie_word_embeddings + ) + + with init_context(): + self.model: PreTrainedModel = AutoModel.from_pretrained( + local_model_path, + config=config, + torch_dtype=torch.float32, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + # Apply Liger kernel if use_liger is enabled + if self.config.model.get("use_liger", False): + from liger_kernel.transformers.monkey_patch import ( + _apply_liger_kernel_to_instance, + ) + + _apply_liger_kernel_to_instance(model=self.model) + + if self.config.model.get("lora_rank", 0) > 0: + self.model.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types( + self.config.model.target_modules + ), + "bias": "none", + } + self.model = get_peft_model(self.model, LoraConfig(**lora_config)) + + if self.config.model.enable_gradient_checkpointing: + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + + log_gpu_memory_usage("After model allocation", logger=logger) + + mixed_precision = MixedPrecision( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) + + auto_wrap_policy = get_fsdp_wrap_policy( + self.model, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) + if self.device_mesh.get_rank() == 0: + print(auto_wrap_policy) + + if not self.config.model.fsdp_config.cpu_offload: + cpu_offload = None + else: + cpu_offload = CPUOffload( + offload_params=self.config.model.fsdp_config.offload_params + ) + + self.fsdp_model = FSDP( + module=self.model, + auto_wrap_policy=auto_wrap_policy, + param_init_fn=init_fn, + sharding_strategy=ShardingStrategy.FULL_SHARD, + mixed_precision=mixed_precision, + device_mesh=self.device_mesh, + sync_module_states=True, + device_id=torch.cuda.current_device(), + cpu_offload=cpu_offload, + use_orig_params=False, + ) + + log_gpu_memory_usage("After FSDP wrapping", logger=logger) + + self.optimizer = optim.AdamW( + self.fsdp_model.parameters(), + lr=self.config.optim.lr, + betas=self.config.optim.betas, + weight_decay=self.config.optim.weight_decay, + ) + + log_gpu_memory_usage("After initialize optimizer", logger=logger) + + self.steps_per_epoch = len(self.train_dataloader) + self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs + + if self.device_mesh.get_rank() == 0: + print( + f"Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}" + ) + + num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) + + self.lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=self.optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=self.total_steps, + ) + + def _load_from_checkpoint(self, checkpoint_path): + """Initialize training state from checkpoint.""" + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + + if self.device_mesh.get_rank() == 0: + print(f"Resuming from checkpoint: {checkpoint_path}") + + # Load training state + training_state_path = os.path.join(checkpoint_path, "training_state.pt") + optimizer_state_path = os.path.join(checkpoint_path, "optimizer_state.pt") + + # Only rank 0 loads the full state initially + if self.device_mesh.get_rank() == 0 and os.path.exists(training_state_path): + training_state = torch.load(training_state_path) + epoch = training_state["epoch"] + global_step = training_state["global_step"] + + # Load scheduler state + self.lr_scheduler.load_state_dict(training_state["lr_scheduler"]) + else: + # For other ranks or if file is missing, get step from path + epoch = 0 + global_step = extract_step(checkpoint_path) or 0 + + # Broadcast values to all ranks + if torch.distributed.get_world_size() > 1: + tensor = torch.tensor([epoch, global_step], device="cuda") + torch.distributed.broadcast(tensor, src=0) + if self.device_mesh.get_rank() != 0: + epoch, global_step = tensor.tolist() + + # Load optimizer state if exists + if os.path.exists(optimizer_state_path): + from torch.distributed.fsdp import FullStateDictConfig + from torch.distributed.fsdp.api import FullOptimStateDictConfig + + with FSDP.state_dict_type( + self.fsdp_model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + # Load optimizer state - rank 0 loads, others receive broadcast + if self.device_mesh.get_rank() == 0: + optim_state = torch.load(optimizer_state_path) + else: + optim_state = None + + # Use FSDP utility to load optimizer state + optim_state_dict = FSDP.scatter_full_optim_state_dict( + optim_state, self.fsdp_model + ) + self.optimizer.load_state_dict(optim_state_dict) + + self.current_epoch = epoch + return global_step + + def _compute_loss_and_backward(self, batch, do_backward=True): + """Compute loss with optional sequence parallelism and remove padding features""" + use_sp = ( + self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 + ) + + input_ids = batch["input_ids"].cuda(non_blocking=True) + attention_mask = batch["attention_mask"].cuda(non_blocking=True).bool() + position_ids = batch["position_ids"].cuda(non_blocking=True) + loss_mask = batch["loss_mask"].cuda(non_blocking=True).bool() + + loss_fct = nn.CrossEntropyLoss(reduction="none") + pad_eos_token_id = ( + self.config.data.pad_token_id + if self.config.data.pad_token_id is not None + else self.tokenizer.pad_token_id + ) + + # Context manager for sequence parallel if needed + context = self.sharding_manager if use_sp else nullcontext() + with context: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if not use_sp: + # Standard forward pass without sequence parallel + labels = input_ids.contiguous() + + # Forward pass + # NOTE: loss_mask is of size (batch_size, seq_len - 1) + batch_size = input_ids.shape[0] + masked_input_ids, t, loss_mask_nonflatten = q_sample( + input_ids, + maskable_mask=loss_mask, + mask_token_id=self.tokenizer.mask_token_id, + eos_token_id=( + pad_eos_token_id + if self.config.data.get("treat_eos_as_one", False) + else None + ), + ) + loss_mask = loss_mask_nonflatten.reshape(-1) + + # 2d -> 4d conversion for attention_mask + attention_mask = torch.logical_and( + attention_mask.unsqueeze(1).unsqueeze(-2), + attention_mask.unsqueeze(1).unsqueeze(-1), + ) + + output = self.fsdp_model( + input_ids=masked_input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ) + logits = output.logits + + shift_logits = torch.cat( + [logits[:, 0:1], logits[:, :-1]], dim=1 + ).contiguous() + shift_labels = labels.contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # We use weighted loss + loss_mask = loss_mask.to(loss.device) + loss = loss.masked_fill(~loss_mask, 0) + if self.config.diffusion.token_reweighting: + loss = ( + self.config.diffusion.alpha + * (1 - torch.exp(-loss)) ** self.config.diffusion.gamma + * loss + ) + + if self.config.diffusion.time_reweighting == "original": + weight = 1 / t[:, None].float().expand(labels.size()) + elif self.config.diffusion.time_reweighting == "linear": + weight = 1 - t[:, None].float().expand(labels.size()) + elif self.config.diffusion.time_reweighting == "cart": + # seq_len = self.config.data.max_length + seq_len = input_ids.shape[-1] + weight_matrix = context_adaptive_reweight( + seq_len, cart_p=self.config.diffusion.cart_p + ) + _weight_matrix = weight_matrix[:seq_len, :seq_len].to( + loss_mask.device + ) + non_mask = ~loss_mask_nonflatten.to( + loss.device + ) # loss_mask indicates where is mask + weight = ( + non_mask.type_as(_weight_matrix) + .matmul(_weight_matrix) + .masked_fill(non_mask, 0) + ) + else: + weight = ( + t.new_ones((batch_size, 1)).float().expand(labels.size()) + ) + + loss = loss * weight.reshape(-1) + else: + raise NotImplementedError( + "Sequence parallel is not implemented yet" + ) + + valid_token_this_rank = torch.sum(loss_mask) + loss = torch.sum(loss) / valid_token_this_rank + + if do_backward: + loss.backward() + return loss + + def training_step(self, batch: TensorDict): + self.fsdp_model.train() + + log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) + + self.optimizer.zero_grad() + + log_gpu_memory_usage("After optimizer zero_grad", logger=logger) + + micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) + n_micro_batches = len(micro_batches) + step_loss = 0 + for micro_batch in micro_batches: + loss = ( + self._compute_loss_and_backward(batch=micro_batch, do_backward=False) + / n_micro_batches + ) + loss.backward() + step_loss += loss.item() + + grad_norm = self.fsdp_model.clip_grad_norm_( + max_norm=self.config.optim.clip_grad + ) + + log_gpu_memory_usage("Before optimizer step", logger=logger) + + self.optimizer.step() + + log_gpu_memory_usage("After optimizer step", logger=logger) + + self.lr_scheduler.step() + + # reduce loss across dp ranks + lr = self.lr_scheduler.get_last_lr()[0] + + log_gpu_memory_usage("After offload weights", logger=logger) + + step_loss = torch.tensor(step_loss, device="cuda") + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + + # record how many tokens are padded as EOS + all_num_eos = [] + for micro_batch in micro_batches: + loss_mask = batch["loss_mask"] + num_eos = loss_mask.sum(-1) + num_eos -= num_eos.min() - 1 + all_num_eos.append(num_eos) + all_num_eos = torch.cat(all_num_eos) + + return { + "train/loss": step_loss.detach().item(), + "train/lr(1e-3)": lr * 1e3, + "train/grad_norm": grad_norm, + "train/num_eos_mean": all_num_eos.float().mean().item(), + "train/num_eos_max": all_num_eos.float().max().item(), + } + + def validation_step(self, batch: TensorDict): + self.fsdp_model.eval() + with torch.no_grad(): + loss = self._compute_loss_and_backward(batch, do_backward=False) + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + return loss + + def save_checkpoint(self, step): + """Save model, optimizer, and training state.""" + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + from torch.distributed.fsdp.api import FullOptimStateDictConfig + + # Create checkpoint directory + path = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{step}" + ) + + # Save model state + model_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + with FSDP.state_dict_type( + self.fsdp_model, StateDictType.FULL_STATE_DICT, model_cfg + ): + model_state = self.fsdp_model.state_dict() + optim_state = FSDP.full_optim_state_dict(self.fsdp_model, self.optimizer) + + # Save training state + training_state = { + "lr_scheduler": self.lr_scheduler.state_dict(), + "global_step": step, + "epoch": self.current_epoch, + } + + # Save on rank 0 only + if self.device_mesh.get_rank() == 0: + os.makedirs(path, exist_ok=True) + + # Save model using HF's save_pretrained + self.model.save_pretrained(path, state_dict=model_state) + self.tokenizer.save_pretrained(path) + + # Save optimizer and training state + torch.save(optim_state, os.path.join(path, "optimizer_state.pt")) + torch.save(training_state, os.path.join(path, "training_state.pt")) + + # Copy to HDFS if configured + if self.config.trainer.default_hdfs_dir: + hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) + hdfs_io.copy( + src=path, + dst=self.config.trainer.default_hdfs_dir, + dirs_exist_ok=True, + ) + torch.distributed.barrier() + + def _find_latest_checkpoint(self): + """Find the latest checkpoint in checkpoint directories.""" + latest_checkpoint = None + latest_step = -1 + + # Check local directory first + local_dir = self.config.trainer.default_local_dir + if os.path.exists(local_dir): + checkpoints = [ + d + for d in os.listdir(local_dir) + if os.path.isdir(os.path.join(local_dir, d)) + and d.startswith("global_step_") + ] + + for ckpt in checkpoints: + step = extract_step(ckpt) + if step is not None and step > latest_step: + latest_step = step + latest_checkpoint = os.path.join(local_dir, ckpt) + + # If not found locally and HDFS is configured, check there + if latest_checkpoint is None and self.config.trainer.default_hdfs_dir: + try: + if hdfs_io.exists(self.config.trainer.default_hdfs_dir): + checkpoints = [ + d + for d in hdfs_io.listdir(self.config.trainer.default_hdfs_dir) + if d.startswith("global_step_") + ] + for ckpt in checkpoints: + step = extract_step(ckpt) + if step is not None and step > latest_step: + latest_step = step + remote_path = os.path.join( + self.config.trainer.default_hdfs_dir, ckpt + ) + + # Copy from HDFS to local + local_path = os.path.join(local_dir, ckpt) + os.makedirs(local_dir, exist_ok=True) + hdfs_io.copy( + src=remote_path, dst=local_path, dirs_exist_ok=True + ) + latest_checkpoint = local_path + except Exception as e: + if self.device_mesh.get_rank() == 0: + print(f"Error checking HDFS for checkpoints: {e}") + + return latest_checkpoint + + def fit(self): + rank = self.device_mesh.get_rank() + + # TODO: add a unified tracking + if rank == 0: + tracking = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + ) + + global_step = 0 + + # Handle resuming training + if self.resume_training: + # Find latest checkpoint if not specified + if not self.resume_checkpoint_path: + self.resume_checkpoint_path = self._find_latest_checkpoint() + + if self.resume_checkpoint_path: + global_step = self._load_from_checkpoint(self.resume_checkpoint_path) + if rank == 0: + print( + f"Resumed training from step {global_step}, epoch {self.current_epoch}" + ) + elif rank == 0: + print("No checkpoint found, starting training from scratch") + + # Compute total training steps + total_training_steps = ( + len(self.train_dataloader) * self.config.trainer.total_epochs + ) + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + if rank == 0: + print(f"Total training steps: {self.total_training_steps}") + + # Begin training from the current epoch + for epoch in range(self.current_epoch, self.config.trainer.total_epochs): + self.current_epoch = epoch + self.train_sampler.set_epoch(epoch=epoch) + + # Create a data iterator + dataloader_iter = iter(self.train_dataloader) + + # If resuming mid-epoch, skip to the right position + if epoch == self.current_epoch and global_step > 0 and self.resume_training: + steps_in_epoch = global_step % self.steps_per_epoch + if steps_in_epoch > 0: + if rank == 0: + print( + f"Skipping {steps_in_epoch} steps to resume at the right position" + ) + for _ in range(steps_in_epoch): + try: + next(dataloader_iter) + except StopIteration: + dataloader_iter = iter(self.train_dataloader) + + # Calculate remaining steps in this epoch + remaining_steps = self.steps_per_epoch + if epoch == self.current_epoch and global_step > 0 and self.resume_training: + remaining_steps -= global_step % self.steps_per_epoch + + for data in tqdm( + dataloader_iter, + initial=self.steps_per_epoch - remaining_steps, + total=self.steps_per_epoch, + desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}", + ): + data = TensorDict( + data, batch_size=self.config.data.train_batch_size + ).cuda() + metric = self.training_step(data) + if rank == 0: + tracking.log(data=metric, step=global_step) + global_step += 1 + + # for early exit validation + if global_step >= self.total_training_steps: + # Perform final validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = TensorDict( + val_data, + batch_size=self.config.data.micro_batch_size_per_gpu, + ).cuda() + val_loss = self.validation_step(val_data) + val_losses.append(val_loss) + if rank == 0: + avg_val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": avg_val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + torch.distributed.barrier() + + # Save final checkpoint + self.save_checkpoint(step=global_step) + return + + if global_step % self.config.trainer.save_checkpoint_steps == 0: + # Perform validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = TensorDict( + val_data, + batch_size=self.config.data.micro_batch_size_per_gpu, + ).cuda() + val_loss = self.validation_step(val_data) + val_losses.append(val_loss) + if rank == 0: + avg_val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": avg_val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + torch.distributed.barrier() + + # Save checkpoint + self.save_checkpoint(step=global_step) + + # validation + val_losses = [] + for data in self.val_dataloader: + data = TensorDict( + data, batch_size=self.config.data.micro_batch_size_per_gpu + ).cuda() + val_loss = self.validation_step(data) + val_losses.append(val_loss) + if rank == 0: + val_loss = torch.mean(torch.stack(val_losses)) + metric = {"val/loss": val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + torch.distributed.barrier() + + # save checkpoint + self.save_checkpoint(step=global_step) + + +@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) +def main(config): + local_rank, rank, world_size = initialize_global_process_group() + + device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",) + ) + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh( + device_type="cuda", + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=("dp", "sp"), + ) + trainer = FSDPSFTTrainer( + config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/Prism/Dream/Dream_Baseline/src/trainer/sft_dataset.py b/Prism/Dream/Dream_Baseline/src/trainer/sft_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c3e549f2a1e35d8f9e3236843e4652a7775c5f --- /dev/null +++ b/Prism/Dream/Dream_Baseline/src/trainer/sft_dataset.py @@ -0,0 +1,255 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SFT dataset +- We assume user pass a single parquet file. +- We load all the data into the memory. +- **NOTE**: We support multi-turn prompts. +Each parquet file contains +""" + +from typing import List, Union + +import pandas as pd + +import torch +from datasets import Dataset as HFDataset +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer +from functools import partial + +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.model import compute_position_id_with_mask +from verl.utils import hf_tokenizer + + +class SFTDataset(Dataset): + """ + This is an in-memory SFTDataset + """ + + def __init__( + self, + parquet_files: Union[str, List[str]], + tokenizer, + prompt_key="prompt", + response_key="response", + max_length=1024, + truncation="error", + pad_token_id=None, + pad_input=False, + ): + assert truncation in ["error", "left", "right"] + self.truncation = truncation + + if not isinstance(parquet_files, List): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + if isinstance(tokenizer, str): + tokenizer = hf_tokenizer(tokenizer) + self.tokenizer: PreTrainedTokenizer = tokenizer + + self.prompt_key = prompt_key + self.response_key = response_key + + self.max_length = max_length + self.pad_token_id = ( + pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id + ) + self.pad_input = pad_input + self._download() + self._read_files_and_tokenize() + + def _download(self): + for i, parquet_file in enumerate(self.parquet_files): + self.parquet_files[i] = copy_local_path_from_hdfs( + parquet_file, verbose=True + ) + + def _read_files_and_tokenize(self): + + def series_to_item(ls): + import pandas, numpy + + while ( + isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) + and len(ls) == 1 + ): + ls = ls[0] + return ls + + dataframes = [] + for parquet_file in self.parquet_files: + # read parquet files and cache + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + self.dataframe = pd.concat(dataframes) + + @staticmethod + def _tokenize_static(example, tokenizer, prompt_key, response_key, max_length, truncation, pad_token_id): + prompt = example[prompt_key] + response = example[response_key] + + # apply chat template + if not isinstance(prompt, str): + prompt_chat = list(prompt) + else: + prompt_chat = [{"role": "user", "content": prompt}] + + # string + prompt_chat_str = tokenizer.apply_chat_template( + prompt_chat, add_generation_prompt=True, tokenize=False + ) + response_chat_str = response + tokenizer.eos_token + + # tokenize + prompt_ids_output = tokenizer( + prompt_chat_str, return_tensors="pt", add_special_tokens=False + ) + prompt_ids = prompt_ids_output["input_ids"][0] + prompt_attention_mask = prompt_ids_output["attention_mask"][0] + + response_ids_output = tokenizer( + response_chat_str, return_tensors="pt", add_special_tokens=False + ) + response_ids = response_ids_output["input_ids"][0] + response_attention_mask = response_ids_output["attention_mask"][0] + + prompt_length = prompt_ids.shape[0] + response_length = response_ids.shape[0] + + input_ids = torch.cat((prompt_ids, response_ids), dim=-1) + attention_mask = torch.cat( + (prompt_attention_mask, response_attention_mask), dim=-1 + ) + + # padding to max length + sequence_length = input_ids.shape[0] + if sequence_length < max_length: + padded_input_ids = ( + torch.ones( + size=(max_length - sequence_length,), dtype=input_ids.dtype + ) + * pad_token_id + ) + padded_attention_mask = torch.ones( # NOTE: we use 1 here + size=(max_length - sequence_length,), dtype=attention_mask.dtype + ) + + input_ids = torch.cat((input_ids, padded_input_ids)) + attention_mask = torch.cat((attention_mask, padded_attention_mask)) + elif sequence_length > max_length: + if truncation == "left": + # actually, left truncation may not be reasonable + input_ids = input_ids[-max_length :] + attention_mask = attention_mask[-max_length :] + elif truncation == "right": + input_ids = input_ids[: max_length] + attention_mask = attention_mask[: max_length] + elif truncation == "error": + raise NotImplementedError( + f"{sequence_length=} is larger than {max_length=}" + ) + else: + raise NotImplementedError( + f"Unknown truncation method {truncation}" + ) + + position_ids = compute_position_id_with_mask(attention_mask) + + loss_mask = attention_mask.clone() + if prompt_length > 1: + # mask out prompt for SFT. + loss_mask[: min(prompt_length, loss_mask.size(0))] = 0 + + return { + "input_ids": input_ids.numpy(), + "attention_mask": attention_mask.numpy(), + "position_ids": position_ids.numpy(), + "loss_mask": loss_mask.numpy(), + } + + def _tokenize(self, example): + return self._tokenize_static( + example, + self.tokenizer, + self.prompt_key, + self.response_key, + self.max_length, + self.truncation, + self.pad_token_id + ) + + def __len__(self): + return len(self.dataframe) + + def __getitem__(self, item): + example = self.dataframe.iloc[item] + data = self._tokenize(example) + return { + "input_ids": torch.tensor(data["input_ids"]), + "attention_mask": torch.tensor(data["attention_mask"]), + "position_ids": torch.tensor(data["position_ids"]), + "loss_mask": torch.tensor(data["loss_mask"]), + } + + def save_tokenized(self, path, num_proc=16): + hf_dataset = HFDataset.from_pandas(self.dataframe) + tokenize_fn = partial( + self._tokenize_static, + tokenizer=self.tokenizer, + prompt_key=self.prompt_key, + response_key=self.response_key, + max_length=self.max_length, + truncation=self.truncation, + pad_token_id=self.pad_token_id + ) + hf_dataset = hf_dataset.map(tokenize_fn, num_proc=num_proc) + hf_dataset.to_pandas().to_parquet(path) + + +class TokenizedSFTDataset(Dataset): + """ + This is an in-memory tokenized SFTDataset + """ + + def __init__( + self, + parquet_files: Union[str, List[str]], + ): + if not isinstance(parquet_files, List): + parquet_files = [parquet_files] + + self.parquet_files = parquet_files + self._read_files() + + def _read_files(self): + dataframes = [] + for parquet_file in self.parquet_files: + # read parquet files and cache + dataframe = pd.read_parquet(parquet_file) + dataframes.append(dataframe) + dataframe = pd.concat(dataframes) + self.hf_dataset = HFDataset.from_pandas(dataframe) + self.hf_dataset.set_format( + type="torch", + columns=["input_ids", "attention_mask", "position_ids", "loss_mask"], + ) + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, item): + return self.hf_dataset[item] diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..745a892568bd84b38252e20bbc9a0bea73ddb1db --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu.yaml @@ -0,0 +1,32 @@ +group: mmlu_flan_cot_zeroshot +group_alias: mmlu (flan style, zeroshot cot) +task: + - group: stem + task: + - mmlu_flan_cot_zeroshot_stem + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: other + task: + - mmlu_flan_cot_zeroshot_other + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: social sciences + task: + - mmlu_flan_cot_zeroshot_social_sciences + aggregate_metric_list: + - metric: acc + weight_by_size: True + - group: humanities + task: + - mmlu_flan_cot_zeroshot_humanities + aggregate_metric_list: + - metric: acc + weight_by_size: True +aggregate_metric_list: + - metric: acc + weight_by_size: True +metadata: + version: 2 diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu_flan_cot_zeroshot_template_yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu_flan_cot_zeroshot_template_yaml new file mode 100644 index 0000000000000000000000000000000000000000..7588b67e1905dc7eae3790c33a16de460c465f67 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/_mmlu_flan_cot_zeroshot_template_yaml @@ -0,0 +1,38 @@ +dataset_path: hails/mmlu_no_train # a copy of `cais/mmlu` with no auxiliary_train split +validation_split: validation +fewshot_split: dev +output_type: generate_until +doc_to_text: "Q: {{question.strip()}}\n(A) {{choices[0]}} (B) {{choices[1]}} (C) {{choices[2]}} (D) {{choices[3]}}\nA: Let's think step by step." +doc_to_target: "{{['(A)', '(B)', '(C)', '(D)'][answer]}}" +filter_list: + - name: "strict-match" + filter: + - function: "regex" + regex_pattern: "((?<=The answer is )(.*)(?=.)|(?<=answer is )(.*)(?=.)|(?<=The answer: )(.*)(?=.)|(?<=The final answer: )(.*)(?=.))" + - function: "take_first" + - name: "flexible-extract" + filter: + - function: "multi_choice_regex" + group_select: -1 + ignore_case: true + ignore_punctuation: true + regex_pattern: "(\\([A-Z]\\))" + - function: "take_first" +generation_kwargs: + until: + - "" + - "Q:" + - "<|im_end|>" + do_sample: false + temperature: 0.0 +num_fewshot: 0 +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true + ignore_case: true + ignore_punctuation: true +metadata: + version: 3.0 +dataset_kwargs: + trust_remote_code: true diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_abstract_algebra.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_abstract_algebra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e885b818eae4bbc87374c756b68ecd11e44bd69 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_abstract_algebra.yaml @@ -0,0 +1,6 @@ +"dataset_name": "abstract_algebra" +"description": "The following are multiple choice questions (with answers) about abstract\ + \ algebra.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_abstract_algebra" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_anatomy.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_anatomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f17410a7cc0869223730328f55803d8d424e930 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_anatomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "anatomy" +"description": "The following are multiple choice questions (with answers) about anatomy.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_anatomy" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_astronomy.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_astronomy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5b821f97642ad5987244a0ac4c9988c2fca3857 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_astronomy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "astronomy" +"description": "The following are multiple choice questions (with answers) about astronomy.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_astronomy" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_business_ethics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_business_ethics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b26c679e26b6bd04d77eb5e0bb2ebaddcc515561 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_business_ethics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "business_ethics" +"description": "The following are multiple choice questions (with answers) about business\ + \ ethics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_business_ethics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_biology.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de020f4eaca7fdeb650688f034ee3b5d89490ddc --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_biology" +"description": "The following are multiple choice questions (with answers) about college\ + \ biology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_biology" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_chemistry.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8e5bbcf76b9fb3ad012511b213ffbbd554cd58d --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_chemistry" +"description": "The following are multiple choice questions (with answers) about college\ + \ chemistry.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_chemistry" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_computer_science.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04b5e750949984abcd7889be80485e52c97dba9f --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_computer_science" +"description": "The following are multiple choice questions (with answers) about college\ + \ computer science.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_computer_science" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_mathematics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81c59cc2c20f340a76ed3d945e976ce3c832815c --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_mathematics" +"description": "The following are multiple choice questions (with answers) about college\ + \ mathematics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_mathematics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_physics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82c2bb2ab586be2346237a6aa8b2ea9fd9170c97 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_college_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "college_physics" +"description": "The following are multiple choice questions (with answers) about college\ + \ physics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_college_physics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_computer_security.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_computer_security.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78216a44778fa0f9f1e057d5dc45b998fd5e87fc --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_computer_security.yaml @@ -0,0 +1,6 @@ +"dataset_name": "computer_security" +"description": "The following are multiple choice questions (with answers) about computer\ + \ security.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_computer_security" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_conceptual_physics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_conceptual_physics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52304bdf8eeac624c63331b259255a98866dc2ac --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_conceptual_physics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "conceptual_physics" +"description": "The following are multiple choice questions (with answers) about conceptual\ + \ physics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_conceptual_physics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_econometrics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_econometrics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c5be81c442710f91ad3e1ca6a0651105b2f14e24 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_econometrics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "econometrics" +"description": "The following are multiple choice questions (with answers) about econometrics.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_econometrics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_electrical_engineering.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_electrical_engineering.yaml new file mode 100644 index 0000000000000000000000000000000000000000..934a1a20a69d987904fe9c8b605c93e4ed149309 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_electrical_engineering.yaml @@ -0,0 +1,6 @@ +"dataset_name": "electrical_engineering" +"description": "The following are multiple choice questions (with answers) about electrical\ + \ engineering.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_electrical_engineering" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_elementary_mathematics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_elementary_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96ec81d6a8716ad60a4b3215faa42f3c3b1396d7 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_elementary_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "elementary_mathematics" +"description": "The following are multiple choice questions (with answers) about elementary\ + \ mathematics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_elementary_mathematics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_formal_logic.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_formal_logic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..915c96de78b68bdd2b8b8cbb26f2f8ec0ae24167 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_formal_logic.yaml @@ -0,0 +1,6 @@ +"dataset_name": "formal_logic" +"description": "The following are multiple choice questions (with answers) about formal\ + \ logic.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_formal_logic" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_global_facts.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_global_facts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a1f7491590b80e784360ceb72619efe4d9568f1 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_global_facts.yaml @@ -0,0 +1,6 @@ +"dataset_name": "global_facts" +"description": "The following are multiple choice questions (with answers) about global\ + \ facts.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_global_facts" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_biology.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_biology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c4043d9bd7e6a38d702afa7ccb4028e98001445 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_biology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_biology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school biology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_biology" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_chemistry.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_chemistry.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5aee89159d40e4f7c788cf670d9fa2e405d32c75 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_chemistry.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_chemistry" +"description": "The following are multiple choice questions (with answers) about high\ + \ school chemistry.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_chemistry" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_computer_science.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_computer_science.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb3eb2134bf8e3e8b8e81f29432db3e81b5f2fcf --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_computer_science.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_computer_science" +"description": "The following are multiple choice questions (with answers) about high\ + \ school computer science.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_computer_science" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_european_history.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_european_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6fc261e8fe114ffc9d7be99110d659704018f159 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_european_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_european_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school european history.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_european_history" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_geography.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_geography.yaml new file mode 100644 index 0000000000000000000000000000000000000000..baabc83a9e25b700600fe516d9a84833c32f4f29 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_geography.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_geography" +"description": "The following are multiple choice questions (with answers) about high\ + \ school geography.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_geography" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_government_and_politics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_government_and_politics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41365c509da451280527720e651d5793d1b83960 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_government_and_politics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_government_and_politics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school government and politics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_government_and_politics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_macroeconomics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_macroeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05e62fa85cb3fdf871ec246de43d32c7a5209db1 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_macroeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_macroeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school macroeconomics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_macroeconomics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_mathematics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_mathematics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c9a9ca3b3840ee7169b59a53cec4c595c783cd4e --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_mathematics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_mathematics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school mathematics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_mathematics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_microeconomics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_microeconomics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2fb8639003555bdca712f3dc49ed6e463158be42 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_microeconomics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_microeconomics" +"description": "The following are multiple choice questions (with answers) about high\ + \ school microeconomics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_microeconomics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_psychology.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..999f9be74e2bc278a068c344030ae27f3b2c3006 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_psychology" +"description": "The following are multiple choice questions (with answers) about high\ + \ school psychology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_psychology" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_us_history.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_us_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d09cdcaa3b268d599e055f82c92779d4ecd2bcb --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_us_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_us_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school us history.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_us_history" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_world_history.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_world_history.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28a63b1b9106219486b5487b24396baf44179276 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_high_school_world_history.yaml @@ -0,0 +1,6 @@ +"dataset_name": "high_school_world_history" +"description": "The following are multiple choice questions (with answers) about high\ + \ school world history.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_high_school_world_history" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_aging.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_aging.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5a71bfc38aab72f17a01e3da11fc037ce28ef033 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_aging.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_aging" +"description": "The following are multiple choice questions (with answers) about human\ + \ aging.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_human_aging" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_sexuality.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_sexuality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa9b895b7331b051385a31165c725c2ef976db69 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_human_sexuality.yaml @@ -0,0 +1,6 @@ +"dataset_name": "human_sexuality" +"description": "The following are multiple choice questions (with answers) about human\ + \ sexuality.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_human_sexuality" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_jurisprudence.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_jurisprudence.yaml new file mode 100644 index 0000000000000000000000000000000000000000..642e6ce4f34992cb5be8b840ea481c7a389d9ce8 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_jurisprudence.yaml @@ -0,0 +1,6 @@ +"dataset_name": "jurisprudence" +"description": "The following are multiple choice questions (with answers) about jurisprudence.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_jurisprudence" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_logical_fallacies.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_logical_fallacies.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12594895469fbf0644e1908e4299f93f417703e8 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_logical_fallacies.yaml @@ -0,0 +1,6 @@ +"dataset_name": "logical_fallacies" +"description": "The following are multiple choice questions (with answers) about logical\ + \ fallacies.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_logical_fallacies" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_machine_learning.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_machine_learning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c27feea94ce017e35bcd453d6cbf5c4db5b3334 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_machine_learning.yaml @@ -0,0 +1,6 @@ +"dataset_name": "machine_learning" +"description": "The following are multiple choice questions (with answers) about machine\ + \ learning.\n\n" +"tag": "mmlu_flan_cot_zeroshot_stem" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_machine_learning" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_management.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_management.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1a13763a2bd796821efa251071359ce0acbf1cf --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_management.yaml @@ -0,0 +1,6 @@ +"dataset_name": "management" +"description": "The following are multiple choice questions (with answers) about management.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_management" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_marketing.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_marketing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0fe6e44b7fe464396e85a53f70831bbb48ff8ece --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_marketing.yaml @@ -0,0 +1,6 @@ +"dataset_name": "marketing" +"description": "The following are multiple choice questions (with answers) about marketing.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_marketing" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_medical_genetics.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_medical_genetics.yaml new file mode 100644 index 0000000000000000000000000000000000000000..813b6a3fe90413bd35a11f82624df600d8bf682b --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_medical_genetics.yaml @@ -0,0 +1,6 @@ +"dataset_name": "medical_genetics" +"description": "The following are multiple choice questions (with answers) about medical\ + \ genetics.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_medical_genetics" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_miscellaneous.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_miscellaneous.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2a95e892a8e6d357e6a9f771272d06422b14d1a --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_miscellaneous.yaml @@ -0,0 +1,6 @@ +"dataset_name": "miscellaneous" +"description": "The following are multiple choice questions (with answers) about miscellaneous.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_miscellaneous" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_disputes.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_disputes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6a76a2a7930589f3603fa070e974116b4996e96 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_disputes.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_disputes" +"description": "The following are multiple choice questions (with answers) about moral\ + \ disputes.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_moral_disputes" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_scenarios.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_scenarios.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5f8c4e6f144dcb4c0eb6881b095434c76105bb6 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_moral_scenarios.yaml @@ -0,0 +1,6 @@ +"dataset_name": "moral_scenarios" +"description": "The following are multiple choice questions (with answers) about moral\ + \ scenarios.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_moral_scenarios" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_nutrition.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_nutrition.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0f144cb44e5218d3a70193fddca2a2883e6b1b8 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_nutrition.yaml @@ -0,0 +1,6 @@ +"dataset_name": "nutrition" +"description": "The following are multiple choice questions (with answers) about nutrition.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_nutrition" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_philosophy.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_philosophy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4e4c0c4b6ccd34ebf4ff1133d0e26ddd8dc90d9 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_philosophy.yaml @@ -0,0 +1,6 @@ +"dataset_name": "philosophy" +"description": "The following are multiple choice questions (with answers) about philosophy.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_philosophy" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_accounting.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_accounting.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e94bef0581e5290ff4790b5d48863a198a904879 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_accounting.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_accounting" +"description": "The following are multiple choice questions (with answers) about professional\ + \ accounting.\n\n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_accounting" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_law.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_law.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25239d9a35941d49797c15986cc43213b0ec74d6 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_law.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_law" +"description": "The following are multiple choice questions (with answers) about professional\ + \ law.\n\n" +"tag": "mmlu_flan_cot_zeroshot_humanities" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_law" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_psychology.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_psychology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48758ef76eaf72e4236a8569e041ea03e6626e67 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_professional_psychology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "professional_psychology" +"description": "The following are multiple choice questions (with answers) about professional\ + \ psychology.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_professional_psychology" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_public_relations.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_public_relations.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62a56a4478bf9eafbcf1a8034abfeea6240e99ca --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_public_relations.yaml @@ -0,0 +1,6 @@ +"dataset_name": "public_relations" +"description": "The following are multiple choice questions (with answers) about public\ + \ relations.\n\n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_public_relations" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_sociology.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_sociology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36b4711831ef6fafde0915178e28513692f9c8d5 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_sociology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "sociology" +"description": "The following are multiple choice questions (with answers) about sociology.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_social_sciences" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_sociology" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_virology.yaml b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_virology.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8e427612f45461a5d873edbafb3d6e0eba4e9f1 --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/mmlu_virology.yaml @@ -0,0 +1,6 @@ +"dataset_name": "virology" +"description": "The following are multiple choice questions (with answers) about virology.\n\ + \n" +"tag": "mmlu_flan_cot_zeroshot_other" +"include": "_mmlu_flan_cot_zeroshot_template_yaml" +"task": "mmlu_flan_cot_zeroshot_virology" diff --git a/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72246935de8cf0cf8b256fd1e6c87dfbbb90a2ad --- /dev/null +++ b/Prism/Dream/Dream_Prism/eval_instruct/lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py @@ -0,0 +1,112 @@ +import re +import sys +import unicodedata + +from lm_eval.filters.extraction import RegexFilter + + +class MultiChoiceRegexFilter(RegexFilter): + """ """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, + ) -> None: + """ + regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure + - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. + - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. + group_select: Selects the (group_select)th match from the findall result. + ignore_case: Ignores the case during step 1 matching + ignore_punctuation: Remove the punctuation during step 1 matching + regexes_to_ignore: Remove these regexes during step 1 matching + """ + super().__init__(regex_pattern, group_select, fallback) + self.ignore_case = ignore_case + self.ignore_punctuation = ignore_punctuation + self.regexes_to_ignore = regexes_to_ignore + + def apply(self, resps, docs): + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + + def find_match(regex, resp, convert_dict={}): + match = regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m][0] + match = match.strip() + if match and match in convert_dict: + match = convert_dict[match] + return match + + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + + def filter_ignores(st): + if self.regexes_to_ignore is not None: + for s in self.regexes_to_ignore: + st = re.sub(s, "", st) + + if self.ignore_case: + st = st.lower() + + if self.ignore_punctuation: + # https://stackoverflow.com/a/266162 + st = st.translate(punct_tbl) + return st + + filtered_resps = [] + + for r, doc in zip(resps, docs): + fallback_regexes = [] + choice_to_alpha = {} + next_alpha = "A" + + without_paren_fallback_regexes = [] + without_paren_to_target = {} + + choices = doc["choices"] + for c in choices: + m = filter_ignores(c.strip()) + fallback_regexes.append(f"{re.escape(m)}") + choice_to_alpha[m] = f"({next_alpha})" + + without_paren_fallback_regexes.append(next_alpha) + without_paren_to_target[next_alpha] = f"({next_alpha})" + + next_alpha = chr(ord(next_alpha) + 1) + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + f":[\s]*({without_paren_fallback_regex})" + ) + + filtered = [] + for resp in r: + match = find_match(self.regex, resp) + if not match: + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) + if not match: + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) + if not match: + match = self.fallback + filtered.append(match) + filtered_resps.append(filtered) + + return filtered_resps diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/httpx-0.28.1.dist-info/licenses/LICENSE.md b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/httpx-0.28.1.dist-info/licenses/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..ab79d16a3f4c6c894c028d1f7431811e8711b42b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/httpx-0.28.1.dist-info/licenses/LICENSE.md @@ -0,0 +1,12 @@ +Copyright © 2019, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7052aab265a198be6fcf179de8a050efc9064916 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/_cli_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/_cli_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..508c15bdcc3a0d5d916dedcdb487500d819a6a1e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/_cli_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/auth.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/auth.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32dc374c96d1bd7fbe8e529bbf70b46e5e41ee6c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/auth.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97a23525936b6f3a736c4bbed746e95c74325b7b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/download.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/download.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1c120b990638b27eed864a9efb0f44e5470e359 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/download.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/hf.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/hf.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fafe92b741f54dd9c69adcb27076a05a988265d2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/hf.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/jobs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/jobs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64d5774c77f3b9e2e43bc4a489ea22a3543192a6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/jobs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/lfs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/lfs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66bc3a16166adc2bbeadb034162f2606ae84d3e4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/lfs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/repo.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/repo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a8ff814f8baf1caca60312bb2867890ab4d8cac Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/repo.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/repo_files.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/repo_files.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2f50e22d506642084f7c16d1a81bad288105c87 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/repo_files.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/system.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/system.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..319a7322a60b138f74f60480e9934f52d9122fa3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/system.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/upload.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/upload.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30fa91bbd1c35887b7a572016e078055db8023b7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/upload.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/upload_large_folder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/upload_large_folder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..207d5218951e7a77324d6eed5a35d5de7472237a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/cli/__pycache__/upload_large_folder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bafceedc19aa7f115250c265f18e76d4074aa52 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0de1360f40f38fefbcfaaf12103232b2a1dc27db Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95938771a901dbef6daa7bb6eb81eeaa3b675bbc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/env.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/env.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..302313fdeb8e35f609fa73117350139034a48893 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/env.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b75a9552ed5b535c9869386d0b53bb7f4809702a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f28f5e6b6f65b811fbb0f5ec8acbd3f6c3462ff2 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/repo.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/repo.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afbb21bacb040c8de626b65a33bd2f5b43b23d46 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/repo.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/repo_files.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/repo_files.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35eda32279b50f93029d1f976e60fe34e0a3c55d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/repo_files.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b8b943f3fefb7cacf76a3ded92e2dc9712add29 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/tag.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/tag.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87441fee943202f71b2867061b03e316b1eacab1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/tag.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6e2df5873651415e4ab6f60fc700c1e98892ba8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/upload_large_folder.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/upload_large_folder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3dfa208b22c681e15a451d212063cc034d692f1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/upload_large_folder.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/user.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/user.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11dd8348f08bd512089cbd802efdabba61ec04b9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/user.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42db0a3c18df8d83d279c40c857bb21292ab1363 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/huggingface_hub/commands/__pycache__/version.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/cpp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..862cef30dba49f4341a3c980845fdb7a2c1cbcd5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/cpp.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing_extensions import assert_never + +from torchgen import local +from torchgen.api.types import ( + ArgName, + ArrayCType, + ArrayRefCType, + BaseCType, + BaseTypeToCppMapping, + Binding, + boolT, + ConstRefCType, + CType, + dimnameListT, + intArrayRefT, + iTensorListRefT, + ListCType, + longT, + MutRefCType, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalSymIntArrayRefT, + scalarT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, + TupleCType, + VectorCType, + voidT, +) +from torchgen.model import ( + Argument, + Arguments, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# This file describes the translation of JIT schema to the public C++ +# API, which is what people use when they call functions like at::add. +# +# Prominent characteristics of the C++ API: +# +# - dtype, layout, device and pin_memory are collected into +# a single C++ type TensorOptions (the native functions API +# also has this, but tensor options is really most relevant +# for the C++ API; it makes calling kwarg factory functions +# pleasant) +# +# - defaulting lives here (in fact, the dispatcher is completely +# oblivious of defaults!) +# +# BTW: policy on name collisions: we try not to have types with +# collisions, but functions are fair game to collide + + +def name( + func: FunctionSchema, + *, + faithful_name_for_out_overloads: bool = False, + symint_overload: bool = False, +) -> str: + name = str(func.name.name) + if symint_overload: + name += "_symint" + if func.is_out_fn(): + if faithful_name_for_out_overloads: + name += "_outf" + else: + name += "_out" + + return name + + +# Translation of "value types" in JIT schema to C++ API type. Value +# types look the same no matter if they are argument types or return +# types. Returns None if the type in question is not a value type. +def valuetype_type( + t: Type, + *, + binds: ArgName, + mutable: bool = True, + symint: bool = False, +) -> NamedCType | None: + if isinstance(t, BaseType): + if t.name in (BaseTy.Tensor, BaseTy.Scalar): + return None + elif str(t) == "SymInt": + if symint: + return NamedCType(binds, BaseCType(SymIntT)) + else: + return NamedCType(binds, BaseCType(longT)) + # All other BaseType currently map directly to BaseCppTypes. + return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) + elif isinstance(t, OptionalType): + elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint) + if elem is None: + return None + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if str(t.elem) == "bool": + assert t.size is not None + return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size)) + else: + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translation of types occurring in JIT arguments to a C++ argument type. +# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type. +# For example, we'll return std::vector instead of IntArrayRef. +# See Note [translation from C++ reference to value types] +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = False, +) -> NamedCType: + # If it's a value type, do the value type translation + r = valuetype_type( + t, + binds=binds, + mutable=mutable, + symint=symint, + ) + if r is not None: + return r + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(BaseCType(tensorT))) + else: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType( + binds, MutRefCType(BaseCType(tensorT)) + ) # TODO: fix this discrepancy + else: + return NamedCType( + binds, ConstRefCType(OptionalCType(BaseCType(tensorT))) + ) + elif str(t.elem) == "Scalar": + return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt": + if symint: + return NamedCType(binds, BaseCType(optionalSymIntArrayRefT)) + else: + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + # TODO: remove these special cases, ArrayRef fallthrough works fine + if str(t.elem) == "int": + if remove_non_owning_ref_types: + return NamedCType(binds, VectorCType(BaseCType(longT))) + else: + return NamedCType(binds, BaseCType(intArrayRefT)) + if str(t.elem) == "SymInt": + if remove_non_owning_ref_types: + if symint: + return NamedCType(binds, VectorCType(BaseCType(SymIntT))) + else: + return NamedCType(binds, VectorCType(BaseCType(longT))) + else: + if symint: + return NamedCType(binds, BaseCType(symIntArrayRefT)) + else: + return NamedCType(binds, BaseCType(intArrayRefT)) + if str(t.elem) == "Tensor": + if local.use_ilistref_for_tensor_lists(): + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) + else: + return NamedCType(binds, BaseCType(tensorListT)) + elif str(t.elem) == "Scalar": + return NamedCType(binds, ArrayRefCType(BaseCType(scalarT))) + elif str(t.elem) == "Dimname": + return NamedCType(binds, BaseCType(dimnameListT)) + elif str(t.elem) == "Tensor?": + return NamedCType( + binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))) + ) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Translate a JIT argument into its C++ type +def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds) + + +# Translation of a (non-multi) return type from JIT to C++ +# N.B: returntype_type returns a CType, not a NamedCType. +# This is mostly because of the mismatch between return types and return names. +# e.g. a function with a return type of 'void' has 0 return names, +# and a function with a return type of 'std::tuple' has >1 return name. +def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: + # placeholder is ignored + # NB: symint is ALWAYS respected for return types. So symint argument + # here is IGNORED + r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True) + if r is not None: + return r.type + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + if mutable: + if local.use_const_ref_for_mutable_tensors(): + return ConstRefCType(BaseCType(tensorT)) + else: + return MutRefCType(BaseCType(tensorT)) + else: + # Note [Tensor Copy Returns] + # Currently, we use "Argument.is_write" to determine + # whether or not Tensor return types should be copies or references. + # If that ever changes, take a look at other locations of this note! + return BaseCType(tensorT) + elif t.name == BaseTy.Scalar: + return BaseCType(scalarT) + elif isinstance(t, ListType): + assert not mutable, ( + "Native functions should never return a mutable tensor list. They should return void." + ) + elem = returntype_type(t.elem, mutable=False) + assert t.size is None, f"fixed size list returns not supported: {t}" + return VectorCType(elem) + elif isinstance(t, OptionalType): + elem = returntype_type(t.elem, mutable=mutable) + if str(t.elem) == "Tensor": + return OptionalCType(elem) + + raise AssertionError(f"unrecognized return type {t}") + + +# Translation of a single return to its C++ type +def return_type(r: Return, *, symint: bool = False) -> CType: + return returntype_type(r.type, mutable=r.is_write, symint=symint) + + +# Translation of a full (possibly multi) return from JIT to its C++ type +def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType: + if len(rs) == 0: + return BaseCType(voidT) + elif len(rs) == 1: + return return_type(rs[0], symint=symint) + else: + return TupleCType([return_type(r, symint=symint) for r in rs]) + + +def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: + returns: list[str] = [] + for i, r in enumerate(f.func.returns): + # If we have an inplace function, the return argument is + # implicitly named self. + # TODO: Consider incorporating this into the data model + if f.func.name.name.inplace: + assert i == 0, "illegal inplace function with multiple returns" + name = "self" + # If we are out function, the name is the name of the + # corresponding output function (r.name will get recorded + # in field_name later.) + elif f.func.is_out_fn(): + name = f.func.arguments.out[i].name + # If the return argument is explicitly named... + elif r.name: + name_conflict = any( + r.name == a.name for a in f.func.schema_order_arguments() + ) + if name_conflict and not f.func.is_out_fn(): + name = f"{r.name}_return" + else: + name = r.name + # If there is no explicit name and no fallback name was passed in, we just name the output result, + # unless it's a multi-return, in which case it's result0, + # result1, etc (zero-indexed) + else: + name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}" + returns.append(name) + return returns + + +JIT_TO_CPP_DEFAULT = { + "False": "false", + "True": "true", + "None": "::std::nullopt", # UGH this one is type directed + "Mean": "at::Reduction::Mean", + "[]": "{}", + "contiguous_format": "c10::MemoryFormat::Contiguous", + "long": "at::kLong", +} + + +# Convert a JIT default into C++ expression representing the default +def default_expr(d: str, t: Type, *, symint: bool) -> str: + if d == "None" and str(t) == "Tensor?": + return "{}" + if isinstance(t, BaseType) and t.name is BaseTy.str: + # Schema allows single quotes but C++ needs double + if len(d) >= 2 and d[0] == "'" and d[-1] == "'": + s = "" + i = 1 + while i + 1 < len(d): + if d[i] != "\\": + if d[i] == '"': + s += '\\"' + else: + s += d[i] + i += 1 + else: + if d[i + 1] == "'": + s += "'" + else: + s += d[i : i + 2] + i += 2 + + return f'"{s}"' + + if isinstance(t, OptionalType): + if d == "None": + return "::std::nullopt" + + return default_expr(d, t.elem, symint=symint) + + if isinstance(t, ListType): + if d.startswith("[") and d.endswith("]"): + return "{" + d[1:-1] + "}" + elif symint and d.isdigit() and str(t.elem) == "SymInt": + return f"c10::SymInt({d})" + elif t.size is None: + # NOTE: Sized lists can have scalar defaults + raise ValueError(f"Expected a list default '[...]' but found: '{d}'") + + return JIT_TO_CPP_DEFAULT.get(d, d) + + +# Convert an argument into its C++ API form + + +def argument( + a: Argument | TensorOptionsArguments | SelfArgument, + *, + cpp_no_default_args: set[str], + method: bool, + faithful: bool, + symint: bool = False, + has_tensor_options: bool, +) -> list[Binding]: + def sub_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Binding]: + return argument( + a, + cpp_no_default_args=cpp_no_default_args, + method=method, + faithful=faithful, + symint=symint, + has_tensor_options=has_tensor_options, + ) + + if isinstance(a, Argument): + binds: ArgName + if a.name == "memory_format" and has_tensor_options: + binds = SpecialArgName.possibly_redundant_memory_format + else: + binds = a.name + default: str | None = None + if a.name not in cpp_no_default_args and a.default is not None: + default = default_expr(a.default, a.type, symint=symint) + return [ + Binding( + nctype=argument_type(a, binds=binds, symint=symint), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, TensorOptionsArguments): + if faithful: + return ( + sub_argument(a.dtype) + + sub_argument(a.layout) + + sub_argument(a.device) + + sub_argument(a.pin_memory) + ) + else: + default = None + # Enforced by NativeFunction.__post_init__ + assert "options" not in cpp_no_default_args + if all(x.default == "None" for x in a.all()): + default = "{}" + elif a.dtype.default == "long": + default = "at::kLong" # TODO: this is wrong + return [ + Binding( + nctype=NamedCType("options", BaseCType(tensorOptionsT)), + name="options", + default=default, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + if method: + # Caller is responsible for installing implicit this in context! + return [] + else: + return sub_argument(a.argument) + else: + assert_never(a) + + +def arguments( + arguments: Arguments, + *, + faithful: bool, + symint: bool = False, + method: bool, + cpp_no_default_args: set[str], +) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + if faithful: + args.extend(arguments.non_out) + args.extend(arguments.out) + else: + args.extend(arguments.out) + args.extend(arguments.non_out) + return [ + r.no_default() if faithful else r + for a in args + for r in argument( + a, + faithful=faithful, + symint=symint, + method=method, + has_tensor_options=arguments.tensor_options is not None, + cpp_no_default_args=cpp_no_default_args, + ) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/dispatcher.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..fcca7a60fec1829c5783197055733467fcdd63fe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/dispatcher.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING +from typing_extensions import assert_never + +from torchgen.api import cpp +from torchgen.api.types import ArgName, Binding, CType, NamedCType +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import concatMap + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# This file describes the translation of JIT schema to the dispatcher +# API, the *unboxed* calling convention by which invocations through +# the dispatcher are made. Historically, the dispatcher API matched +# the C++ API, but with the establishment of the boxed API, we've +# made changes to the dispatcher API to so that the unboxed API +# better aligns with the boxed API. The dispatcher API hooks heavily +# into our template based boxing/unboxing machinery, so changes +# to this convention will usually need template updates too. +# +# Prominent characteristics of the dispatcher API: +# +# - dtype, layout, device and pin_memory are represented as separate +# arguments. +# + + +def name(func: FunctionSchema) -> str: + return cpp.name(func) + + +def argumenttype_type( + t: Type, + *, + mutable: bool, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = True, +) -> NamedCType: + # This is a faux amis. If it makes sense in the future to add + # more special cases here, or invert things so cpp.argument_type + # calls this, or just completely inline the function, please do + # it. + return cpp.argumenttype_type( + t, + mutable=mutable, + binds=binds, + symint=symint, + remove_non_owning_ref_types=remove_non_owning_ref_types, + ) + + +def argument_type( + a: Argument, + *, + binds: ArgName, + remove_non_owning_ref_types: bool = False, + symint: bool = True, +) -> NamedCType: + return argumenttype_type( + a.type, + mutable=a.is_write, + binds=binds, + remove_non_owning_ref_types=remove_non_owning_ref_types, + symint=symint, + ) + + +def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType: + # At present, there is no difference. But there could be! + return cpp.returns_type(rs, symint=symint) + + +def jit_arguments(func: FunctionSchema) -> list[Argument]: + def to_argument( + a: Argument | TensorOptionsArguments | SelfArgument, + ) -> list[Argument]: + if isinstance(a, Argument): + return [a] + elif isinstance(a, SelfArgument): + return [a.argument] + elif isinstance(a, TensorOptionsArguments): + return [a.dtype, a.layout, a.device, a.pin_memory] + else: + assert_never(a) + + return list( + concatMap( + to_argument, + itertools.chain( + func.arguments.positional, func.arguments.kwarg_only, func.arguments.out + ), + ) + ) + + +def argument( + a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True +) -> Binding: + return Binding( + nctype=argument_type( + a, + binds=a.name, + remove_non_owning_ref_types=remove_non_owning_ref_types, + symint=symint, + ), + name=a.name, + argument=a, + ) + + +def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]: + return [argument(a, symint=symint) for a in jit_arguments(func)] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/functionalization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/functionalization.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b46b5f14760b2eca447536a1795ade807f89d5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/functionalization.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from torchgen.api import dispatcher +from torchgen.api.types import ( + BaseCppType, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + longT, + NamedCType, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + NativeFunction, + NativeFunctionsViewGroup, +) + + +# This file describes the translation of JIT schema to API's used +# when creating `ViewMeta` specializations that are used by the functionalization pass. +# These API's mostly follow the dispatcher API, with one difference: +# - While the forward function just directly calls into the at::_ops API +# (following the dispatcher convention), the logic here for the reverse function +# is responsible for generating both the call-site, and the declarations +# (which are implemented manually in the at::functionalization::impl namespace). + +# Define some specific lambda input arguments. +base_binding = Binding( + name="base", + nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) + +has_symbolic_inputs_binding = Binding( + name="has_symbolic_inputs", + nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)), + argument=Argument( + name="has_symbolic_inputs", + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) +mutated_view_binding = Binding( + name="mutated_view", + nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))), + argument=Argument( + name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None + ), + default=None, +) +out_index_binding = Binding( + name="out_index", + nctype=NamedCType(name="out_index", type=BaseCType(longT)), + argument=Argument( + name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None + ), + default=None, +) +reapply_views_binding = Binding( + name="reapply_views", + nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)), + argument=Argument( + name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None + ), + default=None, +) + +InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode") +inverse_return_mode_binding = Binding( + name="inverse_return_mode", + nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)), + argument=Argument( + name="inverse_return_mode", + # NB: not actually a bool but it doesn't matter because this isn't used + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) + + +# Name of the `ViewMeta` specialization class created. +def classname(func: FunctionSchema, with_namespace: bool = False) -> str: + namespace = "at::functionalization::" if with_namespace else "" + return f"{namespace}{func.name.unambiguous_name()}_ViewMeta" + + +# Name of the operation called inside the `forward`/`reverse` implementations. +def name( + g: NativeFunctionsViewGroup, + *, + is_reverse: bool, + include_namespace: bool, + reapply_views: bool | None = None, +) -> str: + if reapply_views is None: + # reapply_views is only important for the fwd lambda, + # since we always plumb the runtime "reapply_views" argument into the reverse function. + assert is_reverse + if is_reverse: + return reverse_name(g.view, include_namespace) + # in the forward case, we just directly call into the at::_ops API (so we always need the namespace) + assert include_namespace + assert g.view_copy is not None + api_name = ( + g.view.func.name.unambiguous_name() + if reapply_views + else g.view_copy.func.name.unambiguous_name() + ) + return f"at::_ops::{api_name}::call" + + +def reverse_name(f: NativeFunction, include_namespace: bool) -> str: + # for the reverse: we plumb the "reapply_views" flag into that function and support + # both copy and non-copy variants. (We could avoid doing that, but that would require + # writing out twice as many view inverse functions). + api_name = f.func.name.unambiguous_name() + # in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't) + if include_namespace: + return f"at::functionalization::FunctionalInverses::{api_name}_inverse" + else: + return f"{api_name}_inverse" + + +def returns_type(func: FunctionSchema) -> CType: + # Assertion: all view ops return tensor-like outputs + assert len(func.returns) >= 1 + for ret in func.returns: + assert ret.type.is_tensor_like() + # However, the return type of the lambda is always an individual tensor. + # For multi-tensor outputs, each tensor needs to be tracked individually. + return BaseCType(tensorT) + + +# Checks whether `func` might return more than one value. +def is_multi_output(func: FunctionSchema) -> bool: + return len(func.returns) > 1 or ( + len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None + ) + + +# `ViewMeta` specialization constructor parameters. +def base_ctor_arguments(func: FunctionSchema) -> list[Binding]: + # All specializations are parematerized by `has_symbolic_inputs` flag. + arguments = [has_symbolic_inputs_binding] + + # If `func` might return more than 1 value, we also parameterize this specialization + # with the output index. + if is_multi_output(func): + arguments.append(out_index_binding) + + return arguments + + +# `ViewMeta` specialized class' constructor arguments. +# +# Values needed specifically by this specialization, that the base class does not need. +# Same as the class' attributes, but non-owning. +def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]: + return attributes(func, owning=False) + + +# `ViewMeta` specialized class' non-static member data. +# +# Essential data for calling the instance's `forward` and `reverse functions. You can +# think of them as values that should be captured from the functionalization kernel. +def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]: + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + return [ + reapply_views_binding, + inverse_return_mode_binding, + *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]], + ] + + +def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]: + args = func.arguments.flat_all + assert args[0].type == BaseType(BaseTy.Tensor) + non_self_args = args[1:] + # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API. + # Both of these follow the dispatcher API. + non_self_bindings = [dispatcher.argument(a) for a in non_self_args] + if not is_reverse: + # the forward lambda swaps out the original tensor argument with the lambd arg "base" + return [base_binding] + non_self_bindings + else: + # the reverse lambda does the same, but with an additional "mutated_view" arg + # additionally, we have a calling convention: for view ops that return multiple tensor outputs + # their corresponding view_inverse function takes in an additional index argument. + if is_multi_output(func): + return [ + base_binding, + mutated_view_binding, + inverse_return_mode_binding, + out_index_binding, + ] + non_self_bindings + else: + return [ + base_binding, + mutated_view_binding, + inverse_return_mode_binding, + ] + non_self_bindings diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/lazy.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..1d308afd8136a4e4d3c0b5eb1b89fcbd00c0a5c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/lazy.py @@ -0,0 +1,468 @@ +from __future__ import annotations + +from typing import Any + +from torchgen.api.types import ( + BaseCppType, + BaseCType, + boolT, + CType, + deviceT, + doubleT, + generatorT, + layoutT, + ListCType, + longT, + memoryFormatT, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + stringT, + SymIntT, + VectorCType, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + OptionalType, + Return, + TensorOptionsArguments, + Type, +) + + +_valueT: BaseCppType | None = None + + +# A ValueT is an IR type which represents the computation of a Tensor. In other +# words, a PyTorch user will do operations on lazy tensors, and each output lazy +# tensor internally tracks a ValueT representing the IR node that would have +# actually produced the value of this tensor for real. +# +# This is configurable because different lazy tensor backends (LTC vs XLA) will +# have different IR representations. (Though, arguably, after unification they +# shouldn't!) +def getValueT() -> BaseCppType: + global _valueT + if not _valueT: + raise NotImplementedError( + "The value type needs to be set with setValueT() in run_gen_lazy_tensor()" + ) + + return _valueT + + +def setValueT(val: BaseCppType) -> None: + global _valueT + _valueT = val + + +# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object, +# making it easier to represent special properties of an arg. +tensorListValueT = BaseCppType("torch::lazy", "Value") + + +def process_ir_type( + typ: Type, properties: LazyIrProperties, *, symint: bool +) -> BaseCType | VectorCType | OptionalCType | ListCType: + """ + This function takes a type from NativeFunctions and converts it for use with + lazy tensor codegen. + + Type conversion for lazy currently consists of + (1) changing at::Tensors into lazy::Values + (2) wrapping everything in a BaseCType + (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef) + + (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.) + There is special handling for Optional[Tensor] or list[Tensor], etc- hence 'tensor-like' + + This is incomplete- there are assertions in places that it's expected to need to add + more types as the codegen is used with more operators. + """ + if isinstance(typ, BaseType): + if typ.name == BaseTy.Tensor: + return BaseCType(getValueT()) + elif typ.name == BaseTy.Scalar: + if properties.TreatScalarsAsConstants: + return BaseCType(scalarT) + # at::scalar has special handling, + # and is wrapped in an lazy::Value just like at::tensor + return BaseCType(getValueT()) + elif typ.name == BaseTy.ScalarType: + return BaseCType(scalarTypeT) + elif typ.name == BaseTy.int: + return BaseCType(longT) + elif typ.name == BaseTy.SymInt: + if symint: + return BaseCType(getValueT()) + else: + return BaseCType(longT) + elif typ.name == BaseTy.bool: + return BaseCType(boolT) + elif typ.name == BaseTy.float: + return BaseCType(doubleT) + elif typ.name == BaseTy.str: + return BaseCType(stringT) + elif typ.name == BaseTy.Device: + return BaseCType(deviceT) + elif typ.name == BaseTy.Generator: + return BaseCType(generatorT) + elif typ.name == BaseTy.Layout: + return BaseCType(layoutT) + elif typ.name == BaseTy.MemoryFormat: + return BaseCType(memoryFormatT) + else: + raise AssertionError(f"TODO add support for type {repr(typ)}") + elif isinstance(typ, OptionalType): + return OptionalCType(process_ir_type(typ.elem, properties, symint=symint)) + elif isinstance(typ, ListType): + if str(typ.elem) == "Tensor?": + # TODO(whc) is this actually correct? or should it use a Vector like above + return ListCType(OptionalCType(BaseCType(getValueT()))) + elif str(typ.elem) == "Tensor": + # this is a TensorList which comes in from GetTensorList as a Value + return BaseCType(tensorListValueT) + elif typ.elem == BaseType(BaseTy.SymInt): + # TODO: return a value type. The problem here is analogous to + # the problem with tensorListValueT: if you have SymInt[] you + # cannot conveniently save the list of Value directly, as nodes + # expect to save values as a vector for ALL arguments. So you + # need a separate IR node that represents all of the size nodes + # assembled into a list. I'm not an LTC dev so I don't want to + # figure it out right now. Y'all figure it out... + return VectorCType(BaseCType(longT)) + + else: + return VectorCType(process_ir_type(typ.elem, properties, symint=symint)) + else: + raise AssertionError(f"unrecognized type {repr(typ)}") + + +# TODO: Determining this based off of CType is bad; this should be computed +# from Type directly; then the same logic as process_ir_type can be used +# +# Invariant: passed typ should be an *owning* CType (e.g., we will report +# that ArrayRef is NOT a value type) +def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool: + """ + Given a type, determine if it is a Value-like type. This is equivalent to + being Tensor-like, but assumes the type has already been transformed. + """ + if isinstance(typ, BaseCType): + # I am regretting my naming conventions, but now we are wrapping at::scalar in + # lazy value, while preserving other 'scalar' types as scalars in the IR + treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants + return ( + typ.type == getValueT() + or (typ.type == scalarT and not treat_scalars_as_constants) + or typ.type == SymIntT + ) + elif typ == VectorCType(BaseCType(SymIntT)): + # TODO: report True for this + return False + elif isinstance(typ, (OptionalCType, ListCType, VectorCType)): + return isValueType(typ.elem, properties) + return False + + +def isSymIntType(typ: Type) -> bool: + return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt + + +def isWrappedScalarType(typ: Type) -> bool: + """ + Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value. + Since we literally change the type from scalarT to valueT, information is lost. + This function helps build a list of wrapped scalars to save that information + """ + if isinstance(typ, BaseType): + # I am regretting my naming conventions, but now we are wrapping at::scalar in + # lazy value, while preserving other 'scalar' types as scalars in the IR + return typ.name == BaseTy.Scalar + elif isinstance(typ, (OptionalType, ListType)): + return isWrappedScalarType(typ.elem) + return False + + +# TODO: dedupe with Type.is_generator_like +def isGeneratorType(typ: Type) -> bool: + if isinstance(typ, BaseType): + return typ.name == BaseTy.Generator + elif isinstance(typ, (OptionalType)): + return isGeneratorType(typ.elem) + return False + + +# This class caches a few derived properties computed from an Argument +# and LazyIrProperties +class LazyArgument: + name: str + orig_type: Type + lazy_type_: CType | None + is_wrapped_scalar: bool + is_generator: bool + # TODO: this is lies, it is false for symint list + is_symint_or_list: bool + + # Whether or not we are treating this as symint or not + symint: bool + + # true if this argument is or contains a lazy IR value + is_lazy_value: bool + + def __init__( + self, arg: Argument, properties: LazyIrProperties, *, symint: bool + ) -> None: + self.name = arg.name + self.orig_type = arg.type + self.symint = symint + self.is_optional = isinstance(arg.type, OptionalType) + self.is_generator = isGeneratorType(arg.type) + self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint) + self.is_wrapped_scalar = isWrappedScalarType(arg.type) + self.is_symint_or_list = symint and ( + isSymIntType(arg.type) + or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem)) + # TODO: lists of symints are not currently treated as value types + # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem)) + ) + + self.is_lazy_value = isValueType(self.lazy_type, properties) + + @property + def lazy_type(self) -> CType: + assert self.lazy_type_ is not None, ( + f"Attempted to access lazy_type for invalid argument {self.name}" + ) + return self.lazy_type_ + + +class LazyIrProperties: + """Collection of properties for an IR node + + The property groups are listed below. Each group is mutually + exclusive, meaning that only one property from each group can be True + at any one time. The properties can be accessed as if they were normal + attributes. The mutual exclusivity is automatically handled. + """ + + Properties: tuple[tuple[str, ...], ...] = ( + ( + "ShapePrecompute", # Assume shape has been precomputed + "ShapeCompute", # Need to compute the shape on construction + "ShapeCache", # Utilize the shape cache to defer computation + ), + ( + "Lower", # Codegen full lower function + "LowerDeclOnly", # Codegen only lower function declaration + ), + ( + "CanBeReused", # Codegen full reuse function + "CanBeReusedDeclOnly", # Codegen only reuse function declaration + ), + ( + "CreateFn", # Codegen full create function + "CreateFnDeclOnly", # Codegen only create function declaration + ), + ( + "TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values + ), + ) + + def __init__(self, *default_properties: str) -> None: + properties: dict[tuple[str, ...], str | None] = dict.fromkeys( + LazyIrProperties.Properties + ) + self.__dict__["properties"] = properties + for p in default_properties: + setattr(self, p, True) + + def __getattr__(self, key: str) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + return properties[values] == key + + return self.__getattribute__(key) + + def __setattr__(self, key: str, value: Any) -> Any: + properties = self.__dict__["properties"] + for values in LazyIrProperties.Properties: + if key in values: + properties[values] = key if value else None + return value + + raise KeyError(f"Invalid property: {key}") + + +# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node. +# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML), +# but carries type information from a native FunctionSchema modified for use with IR nodes, +# and preserving original argument names. +# +# TODO: This is not idiomatic with how other torchgen APIs transform on schema. +class LazyIrSchema: + # The name of the operator this function schema describes. + name: OperatorName + + positional_args: tuple[LazyArgument, ...] + keyword_args: tuple[LazyArgument, ...] + + # TODO: Need to handle collisions with argument names at some point + returns: tuple[Return, ...] + + # if this schema has a Generator arg, list its orig ctype/name but don't + # build a LazyArgument since lazy IR doesn't support it + generator_arg: NamedCType | None = None + + # original function schema + func: FunctionSchema + + # Whether or not we are code-genning for SymInt or not + symint: bool + + properties: LazyIrProperties = LazyIrProperties( + # default properties + "ShapePrecompute", + "Lower", + "CanBeReused", + ) + opkind: str | None = None + + def __init__( + self, + func: FunctionSchema, + properties: LazyIrProperties | None = None, + *, + symint: bool, + ) -> None: + if properties: + self.properties = properties + + self.func = func + self.symint = symint + positional_args: list[LazyArgument] = [] + for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]: + if arg_field == "self_arg" and func.arguments.self_arg is not None: + arg = func.arguments.self_arg.argument + positional_args.append( + LazyArgument(arg, self.properties, symint=symint) + ) + elif getattr(func.arguments, arg_field) is not None: + positional_args.extend( + LazyArgument(arg, self.properties, symint=symint) + for arg in getattr(func.arguments, arg_field) + ) + self.positional_args = tuple(positional_args) + + keyword_args: list[LazyArgument] = [] + for arg_field in [ + "pre_tensor_options_kwarg_only", + "tensor_options", + "post_tensor_options_kwarg_only", + "out", + ]: + curr_args = getattr(func.arguments, arg_field) + if curr_args is not None: + if isinstance(curr_args, TensorOptionsArguments): + curr_args = curr_args.all() + for arg in curr_args: + if isGeneratorType(arg.type): + assert self.generator_arg is None, ( + "We expect there is only one generator arg" + ) + self.generator_arg = NamedCType( + arg.name, + arg.type, # type:ignore[arg-type] + ) + keyword_args.extend( + LazyArgument(arg, self.properties, symint=symint) + for arg in curr_args + ) + self.keyword_args = tuple(keyword_args) + self.name = func.name + self.returns = func.returns + + @property + def node_name(self) -> str: + """ + Return camel-case version of op in node. + + Note: This function also appends any `overload_name` in the operation. + For example, if the op is `bitwise_and.Tensor`, the returned name + will be `BitwiseAndTensor`. + """ + op_name = f"{self.name.name}_{self.name.overload_name}".lower() + return "".join(word.capitalize() or "" for word in op_name.split("_")) + + @property + def aten_name(self) -> str: + return str(self.name.name) + + @property + def base_name(self) -> str: + return f"{self.name.name.base}" + + def filtered_args( + self, + positional: bool = True, + keyword: bool = True, + values: bool = True, + scalars: bool = True, + generator: bool = True, + ) -> list[LazyArgument]: + # This function maintains the sorted order of arguments but provides different filtered views. + # Some parts of the code care about kwargs vs args (TS lowerings), + # other parts care about whether they need to wrap the arg in a lazy value or leave it alone. + # Generators are special cased, as they are needed for fallback/shape-inference but not supported + # in TS lowerings and therefore also omitted from lazy IR. + args: list[LazyArgument] = [] + if positional: + args.extend(self.positional_args) + if keyword: + args.extend(self.keyword_args) + + if values and scalars and generator: + return args + elif values and scalars: + return [a for a in args if not a.is_generator] + elif values: + return [a for a in args if a.is_lazy_value] + elif scalars: + return [ + a + for a in args + if not a.is_lazy_value and (generator or not a.is_generator) + ] + + return [] + + @property + def positional_values(self) -> list[LazyArgument]: + return self.filtered_args( + positional=True, keyword=False, values=True, scalars=False + ) + + @property + def positional_scalars(self) -> list[LazyArgument]: + return self.filtered_args( + positional=True, keyword=False, values=False, scalars=True + ) + + @property + def keyword_values(self) -> list[LazyArgument]: + return self.filtered_args( + positional=False, keyword=True, values=True, scalars=False + ) + + @property + def keyword_scalars(self) -> list[LazyArgument]: + return self.filtered_args( + positional=False, keyword=True, values=False, scalars=True + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/native.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/native.py new file mode 100644 index 0000000000000000000000000000000000000000..632216704d2d47606b977d487335ca196e2e1842 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/native.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing_extensions import assert_never + +from torchgen import local +from torchgen.api import cpp +from torchgen.api.types import ( + ArgName, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + deviceT, + layoutT, + ListCType, + MutRefCType, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + tensorT, +) +from torchgen.model import ( + Argument, + FunctionSchema, + Return, + SelfArgument, + TensorOptionsArguments, + Type, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# This file describes the translation of JIT schema to the native functions API. +# This looks a lot like the C++ API (which makes historical sense, because the +# idea was you wrote native functions to implement functions in the C++ API), +# but over time we have evolved the C++ API without actually changing our +# native:: kernels. The intention is to make native API and dispatcher API +# line up as closely as possible, since this results in the least overhead +# (no translation is needed from dispatcher API to native API). +# +# NB: this is symint aware, you will get the non-SymInt variant for some +# dispatch entries and SymInt for others. + + +def name(func: FunctionSchema) -> str: + name = str(func.name.name) + # TODO: delete this! + if func.is_out_fn(): + name += "_out" + if func.name.overload_name: + name += f"_{func.name.overload_name}" + return name + + +def argumenttype_type( + t: Type, *, mutable: bool, binds: ArgName, symint: bool +) -> NamedCType: + if str(t) == "Tensor?": + tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT)) + if mutable and not local.use_const_ref_for_mutable_tensors(): + return NamedCType(binds, MutRefCType(tensor_type)) + else: + return NamedCType(binds, ConstRefCType(tensor_type)) + elif str(t) == "Tensor?[]": + return NamedCType( + binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))) + ) + elif str(t) == "Scalar": + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif str(t) == "Scalar?": + return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT)))) + return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint) + + +def returns_type(rs: Sequence[Return], *, symint: bool) -> CType: + return cpp.returns_type(rs, symint=symint) + + +def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint) + + +def argument( + a: Argument | SelfArgument | TensorOptionsArguments, + *, + is_out: bool, + symint: bool, +) -> list[Binding]: + # Ideally, we NEVER default native functions. However, there are a number + # of functions that call native:: directly and rely on the defaulting + # existing. So for BC, we generate defaults for non-out variants (but not + # for out variants, where it is impossible to generate an appropriate + # default) + should_default = not is_out + if isinstance(a, Argument): + default: str | None = None + if should_default and a.default is not None: + default = cpp.default_expr(a.default, a.type, symint=symint) + return [ + Binding( + nctype=argument_type(a, binds=a.name, symint=symint), + name=a.name, + default=default, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + # Erase SelfArgument from the distinction + return argument(a.argument, is_out=is_out, symint=symint) + elif isinstance(a, TensorOptionsArguments): + default = None + if should_default: + default = "{}" + # TODO: Not sure why the arguments assigned here are for + # TensorOptionsArguments and not the constituent pieces. It seems + # to matter + return [ + Binding( + nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))), + name="dtype", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))), + name="layout", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))), + name="device", + default=default, + argument=a, + ), + Binding( + nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))), + name="pin_memory", + default=default, + argument=a, + ), + ] + else: + assert_never(a) + + +def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(func.arguments.non_out) + args.extend(func.arguments.out) + return [ + r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn()) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/structured.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e14e5b69e6421fce5ddd247958876061d72b2c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/structured.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing_extensions import assert_never + +from torchgen.api import cpp +from torchgen.api.types import ( + ArgName, + ArrayRefCType, + BaseCType, + Binding, + ConstRefCType, + dimnameListT, + intArrayRefT, + iOptTensorListRefT, + iTensorListRefT, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalTensorRefT, + scalarT, + tensorT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunctionsGroup, + OptionalType, + SelfArgument, + TensorOptionsArguments, + Type, +) + + +# This file describes the translation of JIT schema to the structured functions API. +# This is similar to native API, but a number of historical problems with native +# API have been fixed. + + +# Translation of types occurring in JIT arguments to a C++ argument type. +# NB: For now, mutable doesn't do anything; but it could if we make +# some more nominal types +def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: + # If it's a value type, do the value type translation + # NB: structured kernels ALWAYS have symint off, since they involve actual + # kernels that require real ints. The one exception is the + # CompositeExplicitAutograd and the meta function (which could + # hypothetically be SymInt), but for simplicity we plan for these to just + # be handled in Python + r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) + if r is not None: + return r + + if isinstance(t, BaseType): + if t.name == BaseTy.Tensor: + return NamedCType(binds, ConstRefCType(BaseCType(tensorT))) + elif t.name == BaseTy.Scalar: + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + else: + raise AssertionError(f"base type should have been value type {t}") + elif isinstance(t, OptionalType): + if t.elem == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(optionalTensorRefT)) + elif t.elem == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(optionalScalarRefT)) + elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int": + return NamedCType(binds, BaseCType(optionalIntArrayRefT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, OptionalCType(elem.type)) + elif isinstance(t, ListType): + if t.elem == BaseType(BaseTy.Tensor): + return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT))) + elif t.elem == OptionalType(BaseType(BaseTy.Tensor)): + return NamedCType(binds, BaseCType(iOptTensorListRefT)) + # TODO: delete these special cases; see torchgen.api.cpp--these + # must be changed in tandem, but there are problems; see + # https://github.com/pytorch/pytorch/pull/51485 + elif str(t.elem) == "int": + return NamedCType(binds, BaseCType(intArrayRefT)) + elif str(t.elem) == "Dimname": + return NamedCType(binds, BaseCType(dimnameListT)) + elem = argumenttype_type(t.elem, mutable=mutable, binds=binds) + return NamedCType(binds, ArrayRefCType(elem.type)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def argument_type(a: Argument, *, binds: ArgName) -> NamedCType: + return argumenttype_type(a.type, mutable=a.is_write, binds=binds) + + +# returns_type intentionally omitted, because structured kernels never "return"; +# instead, they always indirectly report their outputs (in the case of a meta +# function, by calling set_output; in the case of an impl function, by writing +# directly into the provided out argument). + + +# Structured kernels are never defaulted +def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]: + if isinstance(a, Argument): + return [ + Binding( + nctype=argument_type(a, binds=a.name), + name=a.name, + default=None, + argument=a, + ) + ] + elif isinstance(a, SelfArgument): + return argument(a.argument) + elif isinstance(a, TensorOptionsArguments): + raise AssertionError("structured kernels don't support TensorOptions yet") + else: + assert_never(a) + + +def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + + if g.out.precomputed: + # A list of parameters for the impl function with + # certain parameters replaced with precomputed counterparts + # as specified in native_functions.yaml. + non_out_args_replaced: list[ + Argument | TensorOptionsArguments | SelfArgument + ] = [] + for a in g.out.func.arguments.non_out: + if isinstance(a, Argument) and a.name in g.out.precomputed.replace: + # If a is in precompute.replace, append the parameters + # that should replace it onto non_out_args_replaced. + non_out_args_replaced.extend(g.out.precomputed.replace[a.name]) + else: + # If not, push a as it is. + non_out_args_replaced.append(a) + + args.extend(non_out_args_replaced) + # g.out.precomputed.add is the list of parameters that are added + # without replacement after the non out args and just before the out args + args.extend(g.out.precomputed.add) + else: + args.extend(g.out.func.arguments.non_out) + + args.extend(g.out.func.arguments.out) + return [r for arg in args for r in argument(arg)] + + +def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(g.functional.func.arguments.non_out) + return [r for arg in args for r in argument(arg)] + + +def out_arguments(g: NativeFunctionsGroup) -> list[Binding]: + args: list[Argument | TensorOptionsArguments | SelfArgument] = [] + args.extend(g.out.func.arguments.out) + return [r for arg in args for r in argument(arg)] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/translate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..f98ce09bbfafb875a619ea01eae7b6f82d76ef71 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/translate.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +from typing import NoReturn, TYPE_CHECKING + +from torchgen.api.types import ( + ArrayRefCType, + BaseCType, + Binding, + boolT, + ConstRefCType, + deviceT, + Expr, + intArrayRefT, + iOptTensorListRefT, + layoutT, + ListCType, + longT, + memoryFormatT, + MutRefCType, + NamedCType, + opmath_t, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalSymIntArrayRefT, + optionalTensorRefT, + scalar_t, + scalarT, + scalarTypeT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorOptionsT, + tensorT, + VectorCType, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# This file implements a small program synthesis engine that implements +# conversions between one API to another. +# +# The key data type in this file in NamedCType, short for Named C++ semantic type. A NamedCType +# represents a C++ type, plus semantic information about what it represents. +# For example, consider the argument "bool pin_memory"; its normal C++ type is +# "bool", but its C++ semantic type also keeps track that this represents a +# "pin_memory"; you can't just use a random other boolean in a context where you +# need a "pin_memory"! +# +# The translator takes a list of needed NamedCTypes, and then figures out how +# to construct expressions with these NamedCTypes from the given bindings. Many +# of these expressions are trivial (I need a Tensor other; there's a Tensor +# other scope); others are more nontrivial and may require packing/unpacking. +# Some examples of non-trivial action: +# +# - Need the "dtype" binding? Well, maybe "dtype" isn't available +# in the context, instead, "options" is, and you need to extract +# it from there. (Gather) +# +# - Need the "context" binding? Well, maybe "context" isn't available +# in the context, and you need to construct it from "dtype", "device", +# etc. (Scatter) +# +# - Need the "memory_format" binding? Well, actually, it's available +# from both "memory_format" and "options", so you had better make sure +# they are consistent. (Join) + +options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) + +out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) + +longVec_ctype = VectorCType(BaseCType(longT)) +longSymVec_ctype = VectorCType(BaseCType(SymIntT)) +optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) +optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) +optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) + + +class UnsatError(RuntimeError): + pass + + +# Given a set of in-scope bindings and a set of target bindings, synthesize +# a list of expressions that uses only the in-scope bindings (bindings) that +# have all of the types of goals. You may want to use this function if +# you're generating code for a function like: +# +# void f({args}) { +# g({exprs}); // g is a different API +# } +# +# and you need to generate "exprs". +# +# Typically, a list of Bindings is convenient to get (you usually call something +# like arguments() to get them); but technically you only need less information: +# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for +# 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing +# something more complicated, e.g., tracking the set of bindings in a context, +# you may find using these smaller types more convenient. +def translate( + bindings: Sequence[Expr | Binding], + goals: Sequence[NamedCType | Binding], + *, + method: bool = False, + allow_expensive_conversions: bool = False, +) -> list[Expr]: + binding_exprs: list[Expr] = [] + for b in bindings: + if isinstance(b, Binding): + binding_exprs.append( + Expr( + expr=b.name, + type=b.nctype, + ) + ) + else: + binding_exprs.append(b) + + goal_ctypes: list[NamedCType] = [] + for g in goals: + if isinstance(g, Binding): + goal_ctypes.append(g.nctype) + else: + goal_ctypes.append(g) + + # Add all the bindings to the context + ctx: dict[NamedCType, str] = {} + for b in binding_exprs: + ctx[b.type] = b.expr + + # While we're at it, do some simple forward inference, looking through + # constructors. + # + # NB: When should you do forward inference versus backward inference? + # The general idea: + # + # - Backward inference WHEN the goal gets smaller + # - Forward inference WHEN the hypothesis gets smaller + # + # This helps ensure termination: backward inference starts with a goal + # and tries to make it simpler and simpler until it's trivial; if the + # goal can grow in size, we blow up to a really huge goal size. + # Similarly, with forward inference we take hypotheses and decompose + # them into simpler hypotheses; if hypotheses could expand in size, + # we also have potential nontermination. (In the code below, forward + # inference is only ever carried out at a single step, but you could + # imagine repeated application of forward inference being profitable.) + # + # A good starting point in the literature for exploring more about proof + # search are these lecture notes + # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf + # + # TODO: My kingdom for a pattern matcher + # https://www.python.org/dev/peps/pep-0634/ + # + # TODO: This could get us in recomputation trouble if b.expr is nontrivial. + # Fix this by implementing some sort of sharing so that if multiple + # goals share the same expression, we only compute it once. This seems + # to matter in practice as compiler is often unwilling to CSE nontrivial + # expressions like scalar.to() + t = b.type + if ( + isinstance(t, ConstRefCType) + and isinstance(t.elem, OptionalCType) + and isinstance(t.elem.elem, BaseCType) + and str(t.elem.elem.type) == "at::Tensor" + ): + ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = ( + f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())" + ) + + if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))): + ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = ( + f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" + ) + + if t.type == ConstRefCType(BaseCType(scalarT)): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to()" + + if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): + ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = ( + f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" + ) + + if t.type == BaseCType(scalar_t): + ctx[NamedCType(t.name, BaseCType(opmath_t))] = ( + f"static_cast({b.expr})" + ) + + # [Note: IOptTensorListRef] + if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): + ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = ( + f"at::IOptTensorListRef({b.expr})" + ) + + # Add implicit bindings if the generated code is inside a Tensor method + if method: + ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = ( + "const_cast(*this)" + ) + ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = ( + "const_cast(*this)" + ) + # This is better! Byte-for-byte compat + # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this" + + def unsat(goal: NamedCType) -> NoReturn: + ctx_desc = "\n".join( + f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items() + ) + raise UnsatError( + f""" +Failed to synthesize the expression "{goal.cpp_type()} {goal.name}". +When I failed, the following bindings were available in the context: + +{ctx_desc} + +This probably means there is a missing rule in the rules of torchgen.api.translate. +Check this module for more information. +""" + ) + + # A shitty backtracking search implementation. It's shitty because it + # does backtracking via stack (bad idea!) and for the most part tries to + # avoid backtracking. In particular, if + # direct=True, we won't try to do any fancy synthesis, just trivial + # conversions (e.g., "T a" is OK for "const T& a"). So all of the + # existing rules in this function simply try to solve immediately, + # and bail if things don't work out. + def solve(goal: NamedCType, *, direct: bool) -> str: + def direct_solve(goal: NamedCType) -> str: + return solve(goal, direct=True) + + if goal in ctx: + # Trivial + return ctx[goal] + + # const & is satisfied with mutable & + if isinstance(goal.type, ConstRefCType): + try: + # WARNING: not strictly decreasing; be careful not + # to add a direct conversion that goes satisfies + # mutable& with const& + return solve( + NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct + ) + except UnsatError: + pass + + # mutable & is satisfied with value + if isinstance(goal.type, MutRefCType): + try: + return solve(NamedCType(goal.name, goal.type.elem), direct=direct) + except UnsatError: + pass + + # TODO: These are referentially equal, shouldn't have to do this; + # ensuring we don't use type synonym IntArrayRef in codegen would + # help + if goal.type == ArrayRefCType(BaseCType(longT)): + return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct) + + if direct: + unsat(goal) + + # For now, all of these rules are mutually exclusive. + if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))): + memory_format = direct_solve( + NamedCType( + SpecialArgName.possibly_redundant_memory_format, + OptionalCType(BaseCType(memoryFormatT)), + ) + ) + # No need to join "memory_format" and "options" if the target API takes "options" directly. + # Otherwise it will cause the redundant memory_format error. + if options_ctype in goal_ctypes: + return memory_format + try: + options = direct_solve(options_ctype) + return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})" + except UnsatError: + return memory_format + elif goal == NamedCType("options", BaseCType(tensorOptionsT)): + dtype = direct_solve( + NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))) + ) + pin_memory = direct_solve( + NamedCType("pin_memory", OptionalCType(BaseCType(boolT))) + ) + device = direct_solve( + NamedCType("device", OptionalCType(BaseCType(deviceT))) + ) + layout = direct_solve( + NamedCType("layout", OptionalCType(BaseCType(layoutT))) + ) + return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})" + + elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))): + try: + options = direct_solve(options_ctype) + return f"c10::optTypeMetaToScalarType({options}.dtype_opt())" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.scalar_type()" + + elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))): + try: + options = direct_solve(options_ctype) + return f"{options}.layout_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.layout()" + + elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))): + try: + options = direct_solve(options_ctype) + return f"{options}.device_opt()" + except UnsatError: + out_tensor = direct_solve(out_tensor_ctype) + return f"{out_tensor}.device()" + + elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))): + try: + options = direct_solve(options_ctype) + return f"{options}.pinned_memory_opt()" + except UnsatError: + # If we're calling a factory op from its out= variant, + # We don't actually care about the value of pin_memory. + out_tensor = direct_solve(out_tensor_ctype) + return "::std::nullopt" + + # We can always do translations from value types to reference types, like vector -> IntArrayRef + elif goal.type == BaseCType(intArrayRefT): + try: + return direct_solve(NamedCType(goal.name, longVec_ctype)) + except UnsatError: + # We can also go SymIntArrayRef -> IntArrayRef + symIntArrayRef_type = direct_solve( + NamedCType(goal.name, BaseCType(symIntArrayRefT)) + ) + return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})" + elif goal.type == BaseCType(symIntArrayRefT): + try: + r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT))) + return f"c10::fromIntArrayRefSlow({r})" + except UnsatError: + return direct_solve(NamedCType(goal.name, longSymVec_ctype)) + elif goal.type == BaseCType(SymIntT): + return direct_solve(NamedCType(goal.name, BaseCType(longT))) + elif goal.type == OptionalCType(BaseCType(SymIntT)): + argname = direct_solve( + NamedCType(goal.name, OptionalCType(BaseCType(longT))) + ) + return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(longT): + symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) + return f"{symInt_type}.guard_int(__FILE__, __LINE__)" + elif goal.type == OptionalCType(BaseCType(longT)): + argname = direct_solve( + NamedCType(goal.name, OptionalCType(BaseCType(SymIntT))) + ) + return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt" + elif goal.type == BaseCType(optionalIntArrayRefT): + try: + return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) + except UnsatError: + argname = direct_solve( + NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT)) + ) + return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(optionalSymIntArrayRefT): + # TODO: You might also want to solve this from longSymVec_ctype or + # an optional version of it + argname = direct_solve( + NamedCType(goal.name, BaseCType(optionalIntArrayRefT)) + ) + return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt" + elif goal.type == BaseCType(optionalScalarRefT): + return direct_solve(NamedCType(goal.name, optionalScalar_ctype)) + elif goal.type == BaseCType(optionalTensorRefT): + return direct_solve(NamedCType(goal.name, optionalTensor_ctype)) + + # Note [translation from C++ reference to value types] + # The below cases are all for when we have an argument with a reference type, + # and a corresponding goal with a value type. + # These are needed when we populate the inputs to a lambda capture and we need + # to guarantee the lifetime of each captured argument. + # We guard it with an explicit kwarg because converting to a value type is expensive + # (O(n)) to convert from IntArrayRef to vector), + # so the caller of translate() should be explicit that they need it. + if allow_expensive_conversions: + if goal.type == VectorCType(BaseCType(longT)): + intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT)) + argname = direct_solve(intArrayRef_ctype) + return f"{argname}.vec()" + if goal.type == VectorCType(BaseCType(SymIntT)): + symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT)) + argname = direct_solve(symIntArrayRef_ctype) + return f"{argname}.vec()" + elif goal.type == OptionalCType(VectorCType(BaseCType(longT))): + optionalIntArrayRef_ctype = NamedCType( + goal.name, BaseCType(optionalIntArrayRefT) + ) + argname = direct_solve(optionalIntArrayRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt" + elif goal.type == OptionalCType(BaseCType(scalarT)): + optionalScalarRef_ctype = NamedCType( + goal.name, BaseCType(optionalScalarRefT) + ) + argname = direct_solve(optionalScalarRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" + elif goal.type == OptionalCType(BaseCType(scalarT)): + optionalTensorRef_ctype = NamedCType( + goal.name, BaseCType(optionalTensorRefT) + ) + argname = direct_solve(optionalTensorRef_ctype) + return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" + # Technically, we also need to handle cases of C++ containers holding reference types. + # But there currently aren't any ops that require lambda capture codegen + # With arguments like ::std::vector. + # If that changes, we'll have to add the translation here. + + # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor. + # We could probably generalize this to non-tensor types too. + if goal.type == MutRefCType(BaseCType(tensorT)): + const_ref_tensor_ctype = NamedCType( + goal.name, ConstRefCType(BaseCType(tensorT)) + ) + argname = direct_solve(const_ref_tensor_ctype) + return f"const_cast({argname})" + + unsat(goal) + + return [Expr(solve(g, direct=False), g) for g in goal_ctypes] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/ufunc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..17adcccecab563b6a4003215c778a00d5e1399c4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/ufunc.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torchgen.api.types as api_types +from torchgen.api import cpp, structured +from torchgen.api.types import ( + ArgName, + BaseCppType, + BaseCType, + Binding, + ConstRefCType, + CType, + NamedCType, + scalarT, +) +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + FunctionSchema, + NativeFunctionsGroup, + Type, +) + + +def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: + assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas" + return f"ufunc_{func.name.name}_{dispatch_key}" + + +def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str: + return schema_kernel_name(g.out.func, dispatch_key) + + +# Tensors are omitted (as they are stored in TensorIterator), everything else is +# passed along (technically, we can pass tensors along too, it just wastes +# argument registers) +# +# NB: used for CPU only +def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None: + # Dispatch stubs are always plain ints + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, ConstRefCType(BaseCType(scalarT))) + elif t == BaseType(BaseTy.Tensor): + return None + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def opmath_type(scalar_t: BaseCppType) -> BaseCppType: + if scalar_t == api_types.scalar_t: + return api_types.opmath_t + raise NotImplementedError + + +# NB: Tensors in constructor are stored in opmath_t, not scalar_t +# because Tensor in constructor = its a scalar tensor partially applied = +# it can be higher precision and we want to compute in that higher precision +# +# NB: CUDA only +def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(opmath_type(scalar_t))) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# Only Tensors ever get passed directly to operator() +# +# NB: CUDA only +# (Actually, this works for CPU too) +def ufunctor_apply_type( + t: Type, *, binds: ArgName, scalar_t: BaseCppType +) -> NamedCType: + if t == BaseType(BaseTy.Tensor): + return NamedCType(binds, BaseCType(scalar_t)) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +# The actual ufunc template function the user writes. Everything here +# is done in the computation type. compute_t is opmath_t in CUDA and scalar_t +# in CPU +def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType: + r = cpp.valuetype_type(t, binds=binds, symint=False) + if r is not None: + return r + + if t == BaseType(BaseTy.Scalar): + return NamedCType(binds, compute_t) + elif t == BaseType(BaseTy.Tensor): + return NamedCType(binds, compute_t) + else: + raise AssertionError(f"unrecognized type {repr(t)}") + + +def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + + +def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding: + return Binding( + nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t), + name=a.name, + default=None, + argument=a, + ) + + +def ufunc_argument(a: Argument, compute_t: CType) -> Binding: + return Binding( + nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t), + name=a.name, + default=None, + argument=a, + ) + + +@dataclass(frozen=True) +class UfunctorBindings: + ctor: list[Binding] + apply: list[Binding] + + +# ufunctors are a CUDA-only concept representing functors that take some of +# their arguments on a host-side constructor, and the rest in the device-side +# apply. E.g., +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers +# to the operator() definition +def ufunctor_arguments( + g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType +) -> UfunctorBindings: + ctor = [] + apply = [] + for a in g.functional.func.arguments.flat_non_out: + if a.type.is_tensor_like(): + if scalar_tensor_idx == 0: + # put it in the ctor anyway + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + scalar_tensor_idx = None + else: + if scalar_tensor_idx is not None: + scalar_tensor_idx -= 1 + apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t)) + else: + ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t)) + assert scalar_tensor_idx is None + return UfunctorBindings(ctor=ctor, apply=apply) + + +# ufuncs are the inner loop template functions that you wrote in ufunc/add.h +# which do the actual computation in question. E.g., +# +# template +# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ { +# return self + alpha * other; +# } +# +# In this file, we refer to T as compute_t which is bound by caller +def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]: + return [ + ufunc_argument(a, compute_t=compute_t) + for a in g.functional.func.arguments.flat_non_out + ] + + +# Stubs are the DispatchStub trampolines that CPU kernels use to get to their +# vectorized versions. E.g., +# +# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); +# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); +def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]: + # stubs drop all tensor arguments (they are implicit in the TensorIterator + # argument and keep everything else) + return [ + r + for a in g.out.func.arguments.flat_non_out + if not a.type.is_tensor_like() + for r in structured.argument(a) + ] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/unboxing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/unboxing.py new file mode 100644 index 0000000000000000000000000000000000000000..edb48ec5d172a7063b4003536506ed33f0f293fa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/unboxing.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.types import Binding, CppSignatureGroup, CType +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + ListType, + NativeFunction, + OptionalType, + Type, +) + + +# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the +# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is +# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the +# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register +# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase. +# +# Here's an example on how the codegen works: +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? memory_format=None) -> Tensor +# - Argument Conversion +# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type. +# - int[] size +# ```cpp +# const c10::List size_list_in = (std::move(peek(stack, 0, 7))).toList(); +# +# std::vector size_vec; +# for (c10::IValue size_elem: size_list_in) { +# int64_t size_base = size_elem.to(); +# size_vec.push_back(size_base); +# } +# at::ArrayRef size_list_out(size_vec); +# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack. +# Will be passed to unboxed kernel. +# ``` +# - Dimname[]? names +# ```cpp +# ::std::optional names_opt = (std::move(peek(stack, 1, 7))).toOptional(); +# ::std::optional> names_opt_out; +# if (names_opt.has_value()) { +# ~~~~~~~~~~~ <-- Unwrapping optional shell +# const c10::IValue names_opt_in = names_opt.value(); +# const c10::List names_list_in = names_opt_in.toList(); +# +# std::vector names_vec; +# for (c10::IValue names_elem: names_list_in) { +# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one. +# at::Dimname names_base = names_elem.to(); +# names_vec.push_back(names_base); +# } +# at::ArrayRef names_list_out(names_vec); +# +# names_opt_out = ::std::optional>(names_list_out); +# } else { +# names_opt_out = ::std::optional>(); +# } +# ``` +# - ScalarType? dtype (similarly for the rest of the arguments) +# ```cpp +# ::std::optional dtype_opt = (std::move(peek(stack, 2, 7))).toOptional(); +# ::std::optional dtype_opt_out; +# if (dtype_opt.has_value()) { +# const c10::IValue dtype_opt_in = dtype_opt.value(); +# at::ScalarType dtype_base = dtype_opt_in.to(); +# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it +# directly using ".to()" API. +# dtype_opt_out = ::std::optional(dtype_base); +# } else { +# dtype_opt_out = ::std::optional(); +# } +# ``` +# +# - Unboxed Kernel Call +# ```cpp +# auto result_ = torch::empty( +# size_list_out, +# names_opt_out, +# options, +# memory_format_opt_out +# ); +# ``` +# +# - Push Result Back to Stack +# ```cpp +# drop(stack, 7); +# pack(stack, std::move(result_)); +# ``` +connector = "\n\t" + + +# Return unboxing function name for a NativeFunction +def name(f: NativeFunction) -> str: + return f.func.name.unambiguous_name() + + +# Convert all the arguments in a NativeFunction to C++ code +def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]: + # we need the 'self' argument so method needs to be False + args = ( + CppSignatureGroup.from_native_function(f, method=False) + .most_faithful_signature() + .arguments() + ) + code_list = [ + f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));" + for i in range(len(args)) + ] + [""] + binding_list = [] + for arg in args: + # expecting only Argument + if not isinstance(arg.argument, Argument): + raise Exception( # noqa: TRY002 + f"Unexpected argument type, expecting `Argument` but got {arg}" + ) + argument: Argument = arg.argument + unboxed_name, _, code, decl = argumenttype_ivalue_convert( + argument.type, + argument.name, + mutable=argument.is_write, + ) + code_list.extend(decl) + code_list.extend(code) + binding_list.append(arg.with_name(unboxed_name)) + return binding_list, code_list + + +# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: +# (1) the C++ code necessary to unbox the argument +# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType +def argumenttype_ivalue_convert( + t: Type, arg_name: str, *, mutable: bool = False +) -> tuple[str, CType, list[str], list[str]]: + # Unboxing is for mobile, which doesn't care about SymInts + ctype = cpp.argumenttype_type( + t=t, mutable=mutable, binds=arg_name, symint=False + ).type + + if isinstance(t, BaseType): + out_name = f"{arg_name}_base" + code, decl = _gen_code_base_type( + arg_name=arg_name, out_name=out_name, ctype=ctype + ) + elif isinstance(t, OptionalType): + out_name = f"{arg_name}_opt_out" + code, decl = _gen_code_optional_type( + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, + ) + elif isinstance(t, ListType): + out_name = f"{arg_name}_list_out" + code, decl = _gen_code_list_type( + arg_name=arg_name, + out_name=out_name, + t=t, + ctype=ctype, + ) + else: + raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}") # noqa: TRY002 + return out_name, ctype, code, decl + + +def _gen_code_base_type( + arg_name: str, out_name: str, ctype: CType +) -> tuple[list[str], list[str]]: + return [ + f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" + ], [] + + +def _gen_code_optional_type( + arg_name: str, out_name: str, t: OptionalType, ctype: CType +) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_opt_in" + res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name) + return ( + f""" +auto {arg_name}_opt = {arg_name}.toOptional(); +{ctype.cpp_type(strip_ref=True)} {out_name}; +if ({arg_name}_opt.has_value()) {{ + const c10::IValue {in_name} = {arg_name}_opt.value(); + {connector.join(res_code)} + {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name}); +}} else {{ + {out_name} = {ctype.cpp_type(strip_ref=True)}(); +}} + """.split("\n"), + decl, + ) + + +def _gen_code_list_type( + arg_name: str, out_name: str, t: ListType, ctype: CType +) -> tuple[list[str], list[str]]: + in_name = f"{arg_name}_list_in" + elem_name = f"{arg_name}_elem" + code = [f"const c10::List {in_name} = {arg_name}.toList();"] + res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name) + # handle list type with size, e.g., bool[4] + if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size: + code.extend( + f""" +{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name}); + """.split("\n") + ) + # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> + elif isinstance(t.elem, OptionalType): + code.extend( + f""" +{ctype.cpp_type(strip_ref=True)} {out_name}; +for (c10::IValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {out_name}.push_back({res_name}); +}} + """.split("\n") + ) + else: + # use ArrayRef as default. + vec_name = arg_name + "_vec" + # need to bring vector instantiation out of scope so that ArrayRef has valid data + decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};") + code.extend( + f""" +for (c10::IValue {elem_name}: {in_name}) {{ + {connector.join(res_code)} + {vec_name}.push_back({res_name}); +}} +{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); + """.split("\n") + ) + return code, decl diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f08a743ae2dc766530fd8f93be9ebb8b7733f21 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__init__.py @@ -0,0 +1,19 @@ +from torchgen.dest.lazy_ir import ( + generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes, + GenLazyIR as GenLazyIR, + GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition, + GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition, +) +from torchgen.dest.native_functions import ( + compute_native_function_declaration as compute_native_function_declaration, +) +from torchgen.dest.register_dispatch_key import ( + gen_registration_headers as gen_registration_headers, + gen_registration_helpers as gen_registration_helpers, + RegisterDispatchKey as RegisterDispatchKey, +) +from torchgen.dest.ufunc import ( + compute_ufunc_cpu as compute_ufunc_cpu, + compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel, + compute_ufunc_cuda as compute_ufunc_cuda, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/lazy_ir.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/lazy_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..b912b8f2427f8848b1a65736f9b36b71b85c06ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/lazy_ir.py @@ -0,0 +1,707 @@ +from __future__ import annotations + +import itertools +from abc import ABC +from dataclasses import dataclass +from typing import Any + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.lazy import ( + getValueT, + isValueType, + LazyArgument, + LazyIrProperties, + LazyIrSchema, + tensorListValueT, +) +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + deviceT, + DispatcherSignature, + kernel_signature, + NativeSignature, + OptionalCType, + VectorCType, +) +from torchgen.context import method_with_native_function +from torchgen.dest.lazy_ts_lowering import ts_lowering_body +from torchgen.model import ( + Argument, + BackendIndex, + BackendMetadata, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsGroup, +) + + +def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: + """ + Given a LazyArgument, + generate a c++ string for materializing an rvalue of that arg for passing into + a lazy Node constructor. + """ + + # TODO: Matching on CType seems wrong; should be matching on Type + if isValueType(arg.lazy_type): + if isinstance(arg.lazy_type, BaseCType): + if arg.is_wrapped_scalar: + return f"node_{arg.name}" + elif arg.lazy_type.type is tensorListValueT: + return f"lazy_{arg.name}_tensorlist" + elif arg.is_symint_or_list: + return f"GetSymIntValue({arg.name})" + return f"lazy_{arg.name}->GetIrValue()" + elif isinstance(arg.lazy_type, OptionalCType): + if arg.is_symint_or_list: + # TODO: I don't understand when you should put lazy_ in the name + # or not + return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt" + elif arg.is_wrapped_scalar: + return f"node_{arg.name}" + return ( + f"lazy_{arg.name} ? " + f"std::make_optional(lazy_{arg.name}->GetIrValue()) : " + "::std::nullopt" + ) + else: + raise AssertionError( + f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" + ) + else: + # NB: this is here because right now we aren't treating SymInt[] as a + # value type; when we do this needs to move above + # NB: we cannot test arg.lazy_type as we've already specified it is an + # int64_t and so we cannot distinguish between SymInt and int64_t + if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType( + BaseTy.SymInt + ): + if arg.symint: + return f"GetSymIntArrayRefValue({arg.name})" + else: + return f"std::vector({arg.name}.begin(), {arg.name}.end())" + elif isinstance(arg.lazy_type, VectorCType) and isinstance( + arg.lazy_type.elem, BaseCType + ): + return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())" + elif ( + isinstance(arg.lazy_type, OptionalCType) + and isinstance(arg.lazy_type.elem, VectorCType) + and isinstance(arg.lazy_type.elem.elem, BaseCType) + ): + return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})" + else: + return f"{arg.name}" + + +def node_ctor_inputs(schema: LazyIrSchema) -> str: + """ + Produce a formatted string with the arguments as passed into the constructor of a node class. + """ + node_ctor_values = [ + node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args() + ] + return ", ".join(node_ctor_values) + + +def gen_fallback_code( + schema: LazyIrSchema, + sig: DispatcherSignature | NativeSignature, + overload_name: str, +) -> str: + """ + Generate code that falls back to eager conditioned on a predicate + """ + dispatcher_sig = DispatcherSignature.from_schema(schema.func) + exprs = translate(sig.arguments(), dispatcher_sig.arguments()) + fallback_args = ",\n ".join([a.expr for a in exprs]) + if len(overload_name): + aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})" + else: + aten_op_str = f"ATEN_OP({schema.aten_name})" + return f""" + if (force_eager_fallback({aten_symbol(schema)})) {{ + return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call( + {fallback_args} + ); + }} +""" + + +def aten_symbol(schema: LazyIrSchema) -> str: + missing_interned_strings = { + "sigmoid_backward", + } + if schema.aten_name in missing_interned_strings: + return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")' + + if not schema.aten_name.startswith("at::"): + return f"at::aten::{schema.aten_name}" + else: + return schema.aten_name + + +# converts all tensor-like arguments to meta tensors. Returns: +# (1) a string containing all of the logic that does the conversions. +# (2) a context, to be used by translate(), with all of the relevant bindings. +def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]: + context: list[Binding] = [] + unwrapped_tensor_args: list[str] = [] + for arg in sig.arguments(): + if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like(): + unwrapped_name = f"{arg.name}_meta" + unwrapped_tensor_args.append( + f"auto {unwrapped_name} = to_meta({arg.name});" + ) + context.append(arg.with_name(unwrapped_name)) + else: + context.append(arg) + unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args) + return unwrap_tensor_args_str, context + + +@dataclass(frozen=True) +class GenLazyIR(ABC): + backend_index: BackendIndex + backend_name: str + node_base: str + use_lazy_shape: bool + + @method_with_native_function + def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]: + func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func + metadata = self.backend_index.get_kernel( + f.functional if isinstance(f, NativeFunctionsGroup) else f + ) + schema = LazyIrSchema( + func, symint=metadata is not None and metadata.supports_symint() + ) + return self.gen(schema) + + # there is no lowering functionality generated unless this IR base class is subclassed and + # implemented as a backend-specific node + def lowering_function(self, schema: LazyIrSchema) -> str: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return "" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + return f"""bool CanBeReused({node_ctor_args}) const {{ + return false; + }}""" + + def node_base_ctor_call(self, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + # backends can customize the way the node base class constructor is called, + # as long as all of its arguments can be generated from information available from the schema + base_ctor_value_args_list = [] + for arg in value_args: + if isinstance(arg.lazy_type, (BaseCType, VectorCType)): + base_ctor_value_args_list.append(f"{arg.name}") + elif isinstance(arg.lazy_type, OptionalCType): + base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)") + else: + raise AssertionError( + f"Unsupported type ({arg.lazy_type}) - add support if necessary" + ) + base_ctor_value_args = ", ".join(base_ctor_value_args_list) + + scalar_args = schema.filtered_args(values=False, scalars=True) + + # Shape construction. + # Conditionally build shape depending on specified shape property + if schema.properties.ShapePrecompute: + shape_ctor_arg = "std::move(shapes)," + elif schema.properties.ShapeCompute: + shape_args = [a.name for a in value_args] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)})," + elif schema.properties.ShapeCache: + shape_args = [f"operand({i})" for i in range(len(value_args))] + shape_args.extend(a.name for a in scalar_args) + shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }}," + else: + shape_ctor_arg = "" + + scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args) + + return f"""{self.node_base}( + {schema.node_name}::ClassOpKind(), + OpList{{{base_ctor_value_args}}}, + {shape_ctor_arg} + /* num_outputs */ {len(schema.returns)}, + torch::lazy::MHash({scalar_hashes}))""" + + def gen(self, schema: LazyIrSchema) -> list[str]: + opkind = schema.opkind or aten_symbol(schema) + + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. + all_args = schema.filtered_args() + scalar_args = schema.filtered_args(values=False, scalars=True) + + ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args] + reuse_ctor_args = ", ".join(ctor_args) + if self.use_lazy_shape and schema.properties.ShapePrecompute: + ctor_args.append("std::vector&& shapes") + node_ctor_args = ", ".join(ctor_args) + + scalar_initializers = ",\n ".join( + [ + # This code is just special casing the mapping from string_view -> strings + f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)" + if a.lazy_type.cpp_type() == "::std::optional" + else f"{a.name}({a.name})" + for a in scalar_args + ] + ) + if len(scalar_initializers): + scalar_initializers = f",\n {scalar_initializers}" + scalar_decls = "\n ".join( + [ + f"std::string {a.name};" + if a.lazy_type.cpp_type() == "c10::string_view" + else f"::std::optional {a.name};" + if a.lazy_type.cpp_type() == "::std::optional" + else f"{a.lazy_type.cpp_type()} {a.name};" + for a in scalar_args + ] + ) + optional_values = [ + arg.name + for arg in schema.filtered_args(values=True, scalars=False) + if isinstance(arg.lazy_type, OptionalCType) + ] + has_optional_decls = "\n ".join( + [f"bool has_{value}: 1;" for value in optional_values] + ) + has_optional_defs = "\n ".join( + [f"has_{value} = !!{value};" for value in optional_values] + ) + members_to_string = [] + for arg in scalar_args: + if isinstance(arg.lazy_type, OptionalCType): + value = f"{arg.name}.value()" + if arg.is_generator: + value = '"torch.Generator()"' + members_to_string.append( + f"""if ({arg.name}.has_value()) {{ + ss << ", {arg.name}=" << {value}; + }} else {{ + ss << ", {arg.name}=null"; + }}""" + ) + else: + members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};') + members_to_string_str = "\n ".join(members_to_string) + + return [ + f"""\ +class {schema.node_name} : public {self.node_base} {{ + public: + static torch::lazy::OpKind ClassOpKind() {{ + return torch::lazy::OpKind({opkind}); + }} + + {schema.node_name}({node_ctor_args}) + : {self.node_base_ctor_call(schema)}{scalar_initializers} + {{ + {has_optional_defs} + }} + + std::string ToString() const override {{ + std::stringstream ss; + ss << {self.node_base}::ToString(); + {members_to_string_str} + return ss.str(); + }} + + {self.create_function(schema, reuse_ctor_args)} + + {self.can_be_reused_function(schema, reuse_ctor_args)} + + {self.lowering_function(schema)} + + {scalar_decls} + {has_optional_decls} + +}}; + +""", + ] + + +@dataclass(frozen=True) +class GenTSLazyIR(GenLazyIR): + def lowering_function(self, schema: LazyIrSchema) -> str: + signature = """ + torch::lazy::TSOpVector Lower( + std::shared_ptr function, + torch::lazy::TSLoweringContext* loctx) const override""" + + if schema.properties.LowerDeclOnly: + return f"{signature};" + elif schema.properties.Lower: + return f"""{signature} {{ + {ts_lowering_body(schema)} + }} + """ + else: + return "" + + def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"static NodePtr Create({node_ctor_args})" + if schema.properties.CreateFnDeclOnly: + return f"{signature};" + elif not schema.properties.CreateFn: + return "" + return f"""{signature} {{ + return ReuseOrMakeNode<{schema.node_name}>(data); + }}""" + + def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str: + signature = f"bool CanBeReused({node_ctor_args}) const" + if schema.properties.CanBeReusedDeclOnly: + return f"{signature};" + elif not schema.properties.CanBeReused: + return "" + value_comparison = [] + for arg in itertools.chain(schema.positional_values, schema.keyword_values): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)" + ) + else: + value_comparison.append(f"operand(i++) == {arg.name}") + for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars): + if isinstance(arg.lazy_type, OptionalCType): + value_comparison.append( + f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))" + ) + else: + value_comparison.append(f"this->{arg.name} == {arg.name}") + value_comparison_str = " &&\n ".join(value_comparison) + + return f"""{signature} {{ + size_t i = 0; + return ({value_comparison_str}); + }}""" + + +@dataclass(frozen=True) +class GenLazyNativeFuncDefinition: + class_method_name: str + backend_index: BackendIndex + tensor_class: str + gen_forced_fallback_code: bool + backend_namespace: str + get_tensorlist: str + get_tensor_or_wrap_number: str + try_get_tensor: str + metrics_counter: str + create_tensor: str + create_from_first_tensor: bool + create_aten_from_ltc_tensor: str + tuple_aten_from_ltc_tensors: str + lazy_tensor_ptr: str + get_device_fn: str + + def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + # Generates lazy_{name} variables for LazyTensors wrapping input tensors + lazy_tensor_decls: list[str] = [] + for arg in value_args: + if arg.is_wrapped_scalar: + if isinstance(arg.lazy_type, OptionalCType): + lazy_tensor_decls.append( + f"""auto node_{arg.name} = {arg.name} ? + std::make_optional(torch::lazy::LazyGraphExecutor::Get()-> + GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)): + ::std::nullopt;""" + ) + else: + lazy_tensor_decls.append( + f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()-> + GetIrValueForScalarFromCodegen({arg.name}, *common_device);""" + ) + elif arg.is_symint_or_list: + continue # values are extracted in isValueType + elif isinstance(arg.lazy_type, BaseCType): + if arg.lazy_type.type is tensorListValueT: + lazy_tensor_decls.append( + f"auto lazy_{arg.name}_tensorlist = " + f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});" + ) + else: + lazy_tensor_decls.append( + f"{self.lazy_tensor_ptr} lazy_{arg.name} = " + f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);" + ) + elif isinstance(arg.lazy_type, OptionalCType): + assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem + # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it + # until we encounter a real world example. + lazy_tensor_decls.append( + f"{self.lazy_tensor_ptr} lazy_{arg.name} = " + f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));" + ) + else: + raise AssertionError( + f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})" + ) + return ("\n ").join(lazy_tensor_decls) + + def force_eager_fallback( + self, + func: NativeFunction, + schema: LazyIrSchema, + metadata: BackendMetadata, + sig: DispatcherSignature | NativeSignature, + ) -> str: + if self.gen_forced_fallback_code: + return gen_fallback_code( + schema, sig, overload_name=func.func.name.overload_name + ) + return "" + + def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str: + return f"{self.metrics_counter};" + + def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str: + value_args = schema.filtered_args(values=True, scalars=False) + scalar_args = schema.filtered_args(values=False, scalars=True) + value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + optional_device = OptionalCType(BaseCType(deviceT)) + optional_devices = [ + a.name for a in scalar_args if a.lazy_type == optional_device + ] + assert len(value_types_names) > 0 or len(optional_devices) > 0, ( + "Expected at least one Value or Device type" + ) + get_device_str = ( + f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})" + ) + return f"""auto common_device = {get_device_str}; + TORCH_INTERNAL_ASSERT(common_device); + """ + + def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str: + metadata = self.backend_index.get_kernel(func) + assert metadata is not None + all_args = schema.filtered_args() + returns_length = len(schema.returns) + # call the meta kernel if it exists, to compute output shape/dtype for our IR + # Note [Generated LTC Shape Functions] + # LTC uses meta tensors from core to do shape inference when possible, and otherwise + # we generate a shape function declaration that needs to be manually implemented. + # How do we detect which ops are eligible to use meta tensors? + # In general we should be able to use meta tensors not just on structured operators, + # but also on composite operators that are implemented in terms of structured kernels. + # We don't currently have a way of knowing at codegen time which ops are implemented that way. + # This is the case for all view and view_copy operators however, so we're going to + # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them). + is_view_copy_op = "view_copy" in func.tags + is_structured = func.structured or func.structured_delegate is not None + if is_structured or is_view_copy_op: + meta_out = """ +std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};""" + if returns_length > 1: + + def this_shape(i: int) -> str: + return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())" + + shapes_str = ",".join([this_shape(i) for i in range(returns_length)]) + meta_out = "std::vector shapes{" + shapes_str + "};" + + # Convert tensor args to the meta device and call it. + # (We can't pass in the input tensors directly, because they are "functional wrappers". + # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.) + # Even at::meta:: functions might redispatch, e.g. if they call into view ops. + dispatcher_sig = DispatcherSignature.from_schema(func.func) + meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) + meta_call_args = [ + e.expr + for e in translate( + meta_call_ctx, dispatcher_sig.arguments(), method=False + ) + ] + if is_view_copy_op: + # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel + assert func.has_composite_explicit_autograd_non_functional_kernel + dispatch_ns = "compositeexplicitautogradnonfunctional" + else: + dispatch_ns = "meta" + aten_name = schema.aten_name + # TODO: this is trolling + if func.func.has_symint() and metadata.supports_symint(): + aten_name += "_symint" + shape_str = f"""\ + {meta_conversion_str} + auto out_meta = at::{dispatch_ns}::{aten_name}({", ".join(meta_call_args)}); + {meta_out}""" + else: + shape_sig = ComputeShapeSignature( + metadata.kernel, func, symint=metadata.supports_symint() + ) + shape_str = f""" + auto shapes = {shape_sig.shape_call};""" + + shape_str += f""" + TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});""" + + # Calculating which dimensions are symbolic + func_schema_str = "aten::" + str(func.func) + shape_str += f""" + if(torch::lazy::symbolicShapeEnabled()){{ + std::vector inputs = {{ {", ".join(str(a.name) for a in all_args)} }}; + const char* schema_str = "{func_schema_str}"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + }} + """ + return shape_str + + def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str: + node_ctor_input_str = node_ctor_inputs(schema) + return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str}); + if (!node) {{ + {self.shape_inference(func, schema)} + node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes)); + CacheNode(node); + }} + """ + + def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str: + # xla uses an instance method for tensor creation, for the time being + if self.create_from_first_tensor: + # TODO(whc) remove this if XLA switches to using static method for creation + assert first_tensor_name is not None, ( + "Requires first tensor to create lazy tensor" + ) + return f"{first_tensor_name}.{self.create_tensor}" + return f"{self.backend_namespace}::{self.create_tensor}" + + def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str: + returns_length = len(schema.returns) + value_args = schema.filtered_args(values=True, scalars=False) + value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar] + first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None + bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}( + {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));""" + + if returns_length > 1: + assert len(value_types_names) > 0, ( + "Code below assumes there is at least one tensor arg" + ) + bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors; + for (int i = 0; i < {returns_length}; i++) {{ + lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device)); + }} + auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);""" + + if schema.name.name.inplace or func.func.is_out_fn(): + assert returns_length == 1, ( + "We assumed there was no such case where an op is an in-place variant " + f"and has tuple outputs, but got tuple of len {returns_length}." + ) + bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node); + auto& result = {first_tensor_name};""" + + bridge_str += """ + return result;""" + return bridge_str + + @method_with_native_function + def __call__(self, func: NativeFunction) -> list[str]: + sig = kernel_signature(func, self.backend_index) + metadata = self.backend_index.get_kernel(func) + assert metadata is not None + schema = LazyIrSchema(func.func, symint=metadata.supports_symint()) + return [ + f"""\ + {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{ + {self.force_eager_fallback(func, schema, metadata, sig)} + {self.metrics(func, schema)} + {self.get_device(func, schema)} + {self.lazy_tensor_decls(func, schema)} + {self.build_ir_node(func, schema)} + {self.return_aten_tensor(func, schema)} + }}\n + """ + ] + + +class ComputeShapeSignature: + """ + Here we use the base name as the suffix of the signature to avoid generating for in-place variants. + """ + + def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None: + self.__schema = LazyIrSchema(f.func, symint=symint) + self.__dispatch_args = ", ".join( + [a.decl() for a in dispatcher.arguments(f.func, symint=symint)] + ) + self.__call_args = ", ".join( + [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)] + ) + self.__kernel_name = kernel_name + + def __decl_suffix(self) -> str: + return f"{self.__kernel_name}({self.__dispatch_args})" + + def __call_suffix(self) -> str: + return f"{self.__kernel_name}({self.__call_args})" + + @property + def shape_decl(self) -> str: + return f"TORCH_API std::vector compute_shape_{self.__decl_suffix()}" + + @property + def shape_call(self) -> str: + return f"torch::lazy::compute_shape_{self.__call_suffix()}" + + +@dataclass(frozen=True) +class GenLazyShapeInferenceDefinition: + backend_index: BackendIndex + tensor_class: str + + @method_with_native_function + def __call__(self, f: NativeFunction) -> list[str]: + metadata = self.backend_index.get_kernel(f) + assert metadata is not None + + # See Note [Generated LTC Shape Functions] + is_view_copy_op = "view_copy" in f.tags + is_structured = f.structured or f.structured_delegate is not None + if is_structured or is_view_copy_op: + return [] + else: + shape_sig = ComputeShapeSignature( + metadata.kernel, f, symint=metadata.supports_symint() + ) + return ["\n".join([f"{shape_sig.shape_decl};"])] + + +def generate_non_native_lazy_ir_nodes( + non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR +) -> list[str]: + """Generate the non-native lazy IR node classes""" + nodes = [] + for op in non_native: + # Set default properties for Non-Native IRs + properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly") + for p in op.get("properties", []): + setattr(properties, p, True) + + # non-native is assumed to want symint bindings if you wrote symint + schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True) + schema.opkind = op.get("opkind") + nodes.append(gen_lazy_ir.gen(schema)[0]) + + return nodes diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/lazy_ts_lowering.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/lazy_ts_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..70161216d8e7c95e194b0d89b345e0da886ef989 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/lazy_ts_lowering.py @@ -0,0 +1,48 @@ +from torchgen.api.lazy import LazyArgument, LazyIrSchema +from torchgen.api.types import OptionalCType + + +def ts_lowering_body(schema: LazyIrSchema) -> str: + # for now, we just want one IR class decl and soon after also the method defs + # and we use the functional version not out/inplace. + emplace_arguments = [] + + def get_value(arg: LazyArgument) -> str: + if isinstance(arg.lazy_type, OptionalCType): + return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr" + return "loctx->GetOutputOp(operand(i++))" + + for arg in schema.positional_args: + if arg.is_lazy_value: + emplace_arguments.append(get_value(arg)) + continue + emplace_arguments.append(f'"{arg.name}", {arg.name}') + + emplace_arguments_str = "\n ".join( + [f"arguments.emplace_back({a});" for a in emplace_arguments] + ) + emplace_kwarg_values = [ + f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values + ] + emplace_kwarg_scalars = [ + f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars + ] + emplace_kwarguments = "\n ".join( + [ + f"kwarguments.emplace_back({a});" + for a in emplace_kwarg_values + emplace_kwarg_scalars + ] + ) + return f"""\ + std::vector arguments; + std::vector kwarguments; + arguments.reserve({len(emplace_arguments)}); + kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)}); + size_t i = 0; + {emplace_arguments_str} + {emplace_kwarguments} + torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments); + TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); + + return {schema.aten_name}_out; +""" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/native_functions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/native_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..05e252d09f9c16888dec66045a92b8aefa19b667 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/native_functions.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.types import kernel_signature +from torchgen.context import with_native_function_and_index +from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup +from torchgen.utils import mapMaybe + + +def torch_api_key_word_prefix(bankend_index: BackendIndex) -> str: + if bankend_index.external: + return "" + + # Although Intel GPU ATen library is out-of-tree, it still utilizes torchgen to produce structured + # kernels. Regarding these produced structured kernels, they should be visible for the Intel GPU ATen + # library. Therefore, we need to add "TORCH_XPU_API" prefix to these structured kernels, + # rather than "TORCH_API". Because the semantic of "TORCH_API" is "hidden" for out-of-tree backends. + # For other in-tree backends like cpu and cuda, they still use "TORCH_API" prefix with "visible" semantic. + device_torch_api_key_word_mapping = { + "XPU": "TORCH_XPU_API", + } + + return ( + device_torch_api_key_word_mapping.get( + bankend_index.dispatch_key.name, "TORCH_API" + ) + + " " + ) + + +@with_native_function_and_index +def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None: + sig = kernel_signature(f, backend_index) + metadata = backend_index.get_kernel(f) + if metadata is None: + return None + if "legacy::" in metadata.kernel: + return None + else: + prefix = "static" if backend_index.external else "TORCH_API" + return f"{prefix} {sig.decl(name=metadata.kernel)};" + + +@with_native_function_and_index +def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]: + meta_name = meta.name(g) + out_args = structured.impl_arguments(g) + metadata = backend_index.get_kernel(g) + if metadata is None: + return [] + prefix = torch_api_key_word_prefix(backend_index) + return [ + f"""\ +struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{ +void impl({", ".join(a.decl() for a in out_args)}); +}}; +""" + ] + + +# Generates NativeFunctions.h, a list of forward declarations of all +# actual kernel definitions we keep in aten/src/ATen/native/ +@with_native_function_and_index +def compute_native_function_declaration( + g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex +) -> list[str]: + metadata = backend_index.get_kernel(g) + if isinstance(g, NativeFunctionsGroup): + if metadata is not None and metadata.structured: + if backend_index.external: + # Structured hasn't been tested with external backends yet. + raise AssertionError( + "Structured external backend functions are not implemented yet." + ) + else: + return gen_structured(g, backend_index) + else: + return list( + mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions()) + ) + else: + x = gen_unstructured(g, backend_index) + return [] if x is None else [x] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/register_dispatch_key.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/register_dispatch_key.py new file mode 100644 index 0000000000000000000000000000000000000000..52bb9602a73f050301e7f4953364d242e2722e54 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/register_dispatch_key.py @@ -0,0 +1,1016 @@ +from __future__ import annotations + +import itertools +import textwrap +from dataclasses import dataclass +from typing import Literal, TYPE_CHECKING +from typing_extensions import assert_never + +import torchgen.api.cpp as cpp +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + ConstRefCType, + CppSignature, + CppSignatureGroup, + DispatcherSignature, + Expr, + kernel_signature, + MutRefCType, + NamedCType, + NativeSignature, + tensorT, +) +from torchgen.context import method_with_native_function, native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + DeviceCheckType, + DispatchKey, + gets_generated_out_inplace_wrapper, + is_cuda_dispatch_key, + NativeFunction, + NativeFunctionsGroup, + SchemaKind, + TensorOptionsArguments, +) +from torchgen.utils import mapMaybe, Target + + +if TYPE_CHECKING: + from torchgen.selective_build.selector import SelectiveBuilder + + +def gen_registration_headers( + backend_index: BackendIndex, + per_operator_headers: bool, + rocm: bool, +) -> list[str]: + if per_operator_headers: + headers = ["#include "] + else: + headers = ["#include "] + + if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta): + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.CUDA: + if rocm: + headers.append("#include ") + else: + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.MPS: + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.XPU: + # XPU specific, this header resides in third_party/torch-xpu-ops + headers.append("#include ") + elif backend_index.dispatch_key == DispatchKey.MTIA: + headers.append("#include ") + elif per_operator_headers: + headers += [ + "#include ", + "#include ", + "#include ", + "#include ", + ] + else: + headers.append("#include ") + + headers.append("#include ") + return headers + + +def gen_empty_impl_names( + backend_index: BackendIndex, +) -> tuple[str | None, str | None]: + empty_impl = None + empty_strided_impl = None + + if backend_index.dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.XPU, + DispatchKey.MTIA, + ): + dispatch = str(backend_index.dispatch_key).lower() + empty_impl = f"at::detail::empty_{dispatch}" + empty_strided_impl = f"at::detail::empty_strided_{dispatch}" + elif backend_index.dispatch_key in ( + DispatchKey.CompositeExplicitAutogradNonFunctional, + DispatchKey.QuantizedCPU, + DispatchKey.QuantizedCUDA, + DispatchKey.XPU, + ): + empty_impl = "at::empty" + empty_strided_impl = "at::empty_strided" + + return empty_impl, empty_strided_impl + + +def gen_create_out_helper(backend_index: BackendIndex) -> list[str]: + if backend_index.dispatch_key == DispatchKey.Meta: + empty_options = "options.device(at::kMeta)" + else: + empty_options = "options" + + empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index) + if empty_impl is None: + return [] + + return [ + f""" +Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{ + if (strides.empty()) {{ + return {empty_impl}(sizes, {empty_options}); + }} else {{ + return {empty_strided_impl}(sizes, strides, {empty_options}); + }} +}} +""" + ] + + +def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]: + _, empty_strided_impl = gen_empty_impl_names(backend_index) + return ( + [] + if empty_strided_impl is None + else [ + f""" +std::optional maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{ + if (out.strides() != strides) {{ + return {empty_strided_impl}(sizes, strides, options); + }} + return std::nullopt; +}} +""" + ] + ) + + +def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]: + if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: + # The function isn't used by this key (since only functional ops have a kernel for this key), + # so we need to not include it to avoid a defined-but-not-used error. + return [] + return [ + """ +void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) { + TORCH_CHECK(options.dtype() == out.dtype(), + "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead"); + TORCH_CHECK(options.device() == out.device(), + "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead"); + const bool resized = at::native::resize_output(out, sizes); + // Only restride if a resize occurred; otherwise we ignore the (advisory) + // strides from the meta function and directly use the output tensor's + // preexisting strides + if (resized) { + if (!strides.empty()) { + TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); + // TODO: avoid the redispatch here + out.as_strided_(sizes, strides); + } else if (options.memory_format_opt().has_value()) { + out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); + } + } +} +""" + ] + + +def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]: + return [ + """ +void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) { + // These checks are needed on those operators that: + // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm') + // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod') + // For other operators (e.g. 'add'), 'TensorIterator' already checks + // these things separately. + TORCH_CHECK(options.dtype() == self.dtype(), + "Bad in-place call: ", + "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match"); + TORCH_CHECK(options.device() == self.device(), + "Bad in-place call: ", + "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match"); + TORCH_CHECK(sizes == self.sizes(), + "Bad in-place call: ", + "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match"); +} +""" + ] + + +def gen_registration_helpers(backend_index: BackendIndex) -> list[str]: + return [ + 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")', + *gen_create_out_helper(backend_index), + *gen_resize_out_helper(backend_index), + *gen_check_inplace_helper(backend_index), + *gen_maybe_create_proxy_helper(backend_index), + "C10_DIAGNOSTIC_POP()", + ] + + +# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp). +# +# - The primary function of this file is to register all of the +# implementations for the given dispatch key to the dispatcher, +# so they are available for use in PyTorch. If dispatch is +# None, we generate schema (def) registrations and catchall +# registrations. +# - The secondary function of this file is to generate a wrapper +# around functions. In CPUType these wrappers do nothing +# (and should be removed), but in other cases they handle +# DeviceGuard. A small extra benefit of wrappers is they +# are not overloaded, so they can be used in the registration +# API without having to disambiguate which overload you want +# (as would be the case if you directly registered native:: +# functions). +# - The tertiary function of this file is to generate *static* +# cpp API bindings which can be used to bypass dispatcher +# directly to kernels, but with user-friendly cpp-style API +@dataclass(frozen=True) +class RegisterDispatchKey: + backend_index: BackendIndex + + target: Literal[ + Target.ANONYMOUS_DEFINITION, + Target.NAMESPACED_DEFINITION, + Target.NAMESPACED_DECLARATION, + Target.REGISTRATION, + ] + + # Selector object to determine which operators to generate + # registration code for. + selector: SelectiveBuilder + + # Whether or not we are actually code-genning for ROCm + rocm: bool + + # Whether or not to generate symint registrations or not. External users + # of codegen who don't care about symints can set this to false to get + # non-SymInt codegen + symint: bool + + # The class that all unstructured native functions live under. This is used to improve + # compiler error messages when a kernel writer adds a native function with the wrong signature. + # This is only used in unstructured kernels, since structured kernels already live in a class. + # Finally, this field is currently Optional because it is only used by external backends. + # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating + # all of the existing kernel signatures scattered across aten/src/ATen/native. + class_method_name: str | None + + # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering + # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher. + skip_dispatcher_op_registration: bool + + @staticmethod + def gen_device_check( + type: DeviceCheckType, args: list[Argument], method_name: str + ) -> str: + if type == DeviceCheckType.NoCheck: + return " // No device check\n" + + device_check = "std::optional common_device = std::nullopt;\n" + device_check += "(void)common_device; // Suppress unused variable warning\n" + for arg in args: + # Only tensor like arguments are eligible + if arg.type.is_tensor_like(): + device_check += f""" + c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");""" + return device_check + + @method_with_native_function + def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]: + if isinstance(f, NativeFunctionsGroup): + g: NativeFunctionsGroup = f + # Note: We call gen_structured() if the operator is marked structured, regardless of the backend. + # gen_structured() has special logic to handle auto-generated kernels. + if g.structured: + return self.gen_structured(g) + else: + return list( + mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()) + ) + elif isinstance(f, NativeFunction): + r = self.gen_unstructured(f) + return [] if r is None else [r] + else: + assert_never(f) + + def wrapper_kernel_sig( + self, f: NativeFunction + ) -> NativeSignature | DispatcherSignature: + # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names. + return DispatcherSignature.from_schema( + f.func, + prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_", + symint=self.symint, + ) + + def gen_out_inplace_wrapper( + self, f: NativeFunction, g: NativeFunctionsGroup | None + ) -> str | None: + if g is None: + return None + k = f.func.kind() + if k is SchemaKind.inplace: + copy_op = "at::_copy_from" + elif k is SchemaKind.out: + copy_op = "at::_copy_from_and_resize" + else: + raise AssertionError("gen_out_inplace_wrapper called on a functional op") + + sig = self.wrapper_kernel_sig(f) + name = sig.name() + + func_res = f"{name}_tmp" + return_names = cpp.return_names(f) + if len(return_names) > 1: + updates = "\n ".join( + f"{copy_op}(std::get<{i}>({func_res}), {ret_name});" + for i, ret_name in enumerate(return_names) + ) + returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})" + elif len(return_names) == 1: + ret_name = return_names[0] + updates = f"{copy_op}({func_res}, {ret_name});" + returns = ret_name + else: + assert len(f.func.arguments.out) == 1 + returns = "" + out_arg = f.func.arguments.out[0] + if out_arg.type.is_list_like(): + updates = f"""\ + for (int64_t i = 0; i < {func_res}.size(); ++i) {{ + {copy_op}({func_res}[i], {out_arg.name}[i]); + }}""" + else: + updates = f"{copy_op}({func_res}, {out_arg.name});" + + functional_sig = self.wrapper_kernel_sig(g.functional) + wrapper_name = sig.name() + + return f"""\ +{sig.defn(name=wrapper_name)} {{ + auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))}); + {updates} + return {returns}; +}} +""" + + def gen_structured(self, g: NativeFunctionsGroup) -> list[str]: + metadata = self.backend_index.get_kernel(g) + if self.backend_index.dispatch_key == DispatchKey.Meta: + assert not self.backend_index.has_kernel(g.out), ( + "Do not explicitly specify Meta dispatch key on structured " + "functions, they will be automatically generated for you" + ) + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + assert not self.backend_index.has_kernel(g.out), ( + "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured " + "functions, they will be automatically generated for you" + ) + elif metadata is None or not metadata.structured: + return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())) + structured_gen = StructuredRegisterDispatchKey( + self.backend_index, + self.target, + self.selector, + self.rocm, + self.symint, + self.class_method_name, + self.skip_dispatcher_op_registration, + g, + ) + return list(mapMaybe(structured_gen.gen_one, g.functions())) + + def gen_unstructured( + self, f: NativeFunction, g: NativeFunctionsGroup | None = None + ) -> str | None: + with native_function_manager(f): + inplace_meta = False + gets_out_inplace_wrapper = False + if not self.backend_index.has_kernel(f): + if ( + self.backend_index.dispatch_key == DispatchKey.Meta + and f.func.kind() is SchemaKind.inplace + and + # Defer to composites for meta implementation + not f.has_composite_kernel + and + # Inplace list operations are not supported + len(f.func.returns) == 1 + ): + inplace_meta = True + elif ( + not self.backend_index.use_out_as_primary + and g is not None + and gets_generated_out_inplace_wrapper(f, g, self.backend_index) + ): + # We want to generate inplace/out wrappers, that don't have a kernel for the backend. + gets_out_inplace_wrapper = True + else: + return None + if f.manual_kernel_registration: + return None + + if ( + self.target is Target.REGISTRATION + and not self.selector.is_native_function_selected(f) + ): + return None + + sig = self.wrapper_kernel_sig(f) + + name = sig.name() + returns_type = sig.returns_type().cpp_type() + args = sig.arguments() + args_str = ", ".join(a.defn() for a in args) + + # See Note [Direct dispatch bindings] + cpp_sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + + # TODO: dedupe this with the structured codegen + if self.target is Target.NAMESPACED_DECLARATION: + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += f"TORCH_API {cpp_sig.decl()};\n" + return result + elif self.target is Target.NAMESPACED_DEFINITION: + + def generate_defn(cpp_sig: CppSignature) -> str: + return f""" +{cpp_sig.defn()} {{ +return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); +}} +""" + + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += generate_defn(cpp_sig) + return result + + elif self.target is Target.ANONYMOUS_DEFINITION: + # short circuit for inplace_meta + if inplace_meta: + assert f.func.arguments.self_arg is not None + self_arg_name = f.func.arguments.self_arg.argument.name + # TODO: handle in place on tensor list + return f""" +{returns_type} {name}({args_str}) {{ + TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(), + "Cannot inplace into non-meta tensor with meta tensor argument"); + return {self_arg_name}; +}} +""" + + # short circuit for generated inplace/out wrappers + if gets_out_inplace_wrapper: + return self.gen_out_inplace_wrapper(f, g) + + metadata = self.backend_index.get_kernel(f) + if metadata is None: + return None + if self.class_method_name is None: + impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}" + else: + impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}" + + kernel_sig = kernel_signature(f, self.backend_index) + + args_exprs_str = ", ".join( + e.expr + for e in translate( + sig.arguments(), kernel_sig.arguments(), method=False + ) + ) + + device_check = " // No device check\n" + # Backends that require device guards presumably also require device checks. + if self.backend_index.device_guard: + device_check_args = itertools.chain( + f.func.arguments.out, f.func.arguments.flat_positional + ) + device_check = RegisterDispatchKey.gen_device_check( + f.device_check, list(device_check_args), name + ) + + device_guard = "// DeviceGuard omitted" # default + if f.device_guard and self.backend_index.device_guard: + has_tensor_options = any( + isinstance(a, TensorOptionsArguments) + for a in f.func.arguments.non_out + ) + if has_tensor_options: + # kernel is creating a tensor + device_guard = """ + const DeviceGuard device_guard(device_or_default(device));""" + + # CUDA requires special handling + if is_cuda_dispatch_key(self.backend_index.dispatch_key): + device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}" + else: + # kernel is operating on existing tensors + + # There is precedence for which argument we use to do + # device guard. This describes the precedence order. + self_arg = ( + [f.func.arguments.self_arg.argument] + if f.func.arguments.self_arg is not None + else [] + ) + candidate_args = itertools.chain( + self_arg, + f.func.arguments.out, + f.func.arguments.flat_positional, + ) + + # Only tensor like arguments are eligible + device_of = next( + ( + f"{a.name}" + for a in candidate_args + if a.type.is_tensor_like() + ), + None, + ) + if device_of is not None: + device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));" + + return f"""\ +namespace {{ + +{returns_type} {name}({args_str}) {{ + {device_check} + + {device_guard} + return {impl_name}({args_exprs_str}); +}} + +}} // anonymous namespace +""" + + elif self.target is Target.REGISTRATION: + if f.manual_kernel_registration or self.skip_dispatcher_op_registration: + return None + else: + payload = f"TORCH_FN({name})" + return f'm.impl("{f.func.name}",\n{payload});\n' + else: + assert_never(self.target) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# STRUCTURED +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@dataclass(frozen=True) +class StructuredRegisterDispatchKey(RegisterDispatchKey): + g: NativeFunctionsGroup + + def gen_class_set_output_functions( + self, k: SchemaKind, parent_class: str, generate_super: bool + ) -> str: + if generate_super: + set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);" + else: + set_output_super = "" + + def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str: + return f""" +void set_output_{name}( + int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, + TensorOptions options, DimnameList names +) override {{ +{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")} + if (!names.empty()) {{ + namedinference::propagate_names(outputs_[output_idx], names); + }} + // super must happen after, so that downstream can use maybe_get_output + // to retrieve the output +{textwrap.indent(set_output_super, " ")} +}} +""" + + return f""" +{gen_set_output_function("strided", maybe_create_proxy=True)} +{gen_set_output_function("raw_strided", maybe_create_proxy=False)} +""" + + def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str: + if self.backend_index.dispatch_key in [ + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.XPU, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ]: + maybe_set_guard = """ +auto current_device = guard_.current_device(); +if (C10_UNLIKELY(current_device.has_value())) { + TORCH_INTERNAL_ASSERT(*current_device == options.device(), + "structured kernels don't support multi-device outputs"); +} else { + guard_.reset_device(options.device()); +} +""" + maybe_set_guard_line = maybe_set_guard + "\n" + else: + maybe_set_guard_line = maybe_set_guard = "" + + if maybe_create_proxy: + create_proxy = """ +auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options); +if (C10_UNLIKELY(maybe_proxy.has_value())) { + proxy_outputs_[output_idx] = std::move(maybe_proxy).value(); +} +""" + else: + create_proxy = "" + + if k is SchemaKind.functional: + assert self.backend_index.dispatch_key in ( + DispatchKey.Meta, + DispatchKey.CPU, + DispatchKey.CUDA, + DispatchKey.MPS, + DispatchKey.XPU, + DispatchKey.MTIA, + DispatchKey.CompositeExplicitAutogradNonFunctional, + ) + return f"""{maybe_set_guard_line} +outputs_[output_idx] = create_out(sizes, strides, options);""" + elif k is SchemaKind.inplace: + return f"""{maybe_set_guard_line} +const auto& out = outputs_[output_idx].get(); +check_inplace(out, sizes, options); +{create_proxy}""" + elif k is SchemaKind.out: + return f"""{maybe_set_guard_line} +const auto& out = outputs_[output_idx].get(); +resize_out(out, sizes, strides, options); +{create_proxy}""" + elif k is SchemaKind.mutable or k is SchemaKind.scratch: + raise AssertionError( + f"{k} structured operators are currently not supported" + ) + else: + assert_never(k) + + # returns the definition of a ctor, as well as how to construct + # this class to a variable named op + def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str: + if k is SchemaKind.functional: + return "" + elif k is SchemaKind.inplace: + # TODO: Make sure out argument is guaranteed to be self + return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}" + elif k is SchemaKind.out: + out_args = ", ".join(f"Tensor& out{i}" for i in range(returns)) + out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns)) + return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}" + elif k is SchemaKind.mutable or k is SchemaKind.scratch: + raise AssertionError( + f"{k} structured operators are currently not supported" + ) + else: + assert_never(k) + + def gen_class( + self, + f: NativeFunction, + k: SchemaKind, + *, + class_name: str, + parent_class: str, + generate_super: bool, + ) -> str: + if k is SchemaKind.functional: + output_type = "Tensor" + output_value = "outputs_[output_idx]" + proxy_field = "" + elif k is SchemaKind.inplace: + output_type = "std::reference_wrapper" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array<::std::optional, {len(f.func.returns)}> proxy_outputs_;" + elif k is SchemaKind.out: + output_type = "std::reference_wrapper" + output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()" + proxy_field = f"std::array<::std::optional, {len(f.func.returns)}> proxy_outputs_;" + else: + raise RuntimeError(f"Unsupported SchemaKind {k}") + + if self.backend_index.dispatch_key == DispatchKey.CUDA: + if self.rocm: + guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;" + else: + guard_field = "c10::cuda::OptionalCUDAGuard guard_;" + elif ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + guard_field = "c10::OptionalDeviceGuard guard_;" + elif self.backend_index.dispatch_key == DispatchKey.MPS: + # TODO: Move to OptionalMPSGuard. + guard_field = "c10::OptionalDeviceGuard guard_;" + elif self.backend_index.dispatch_key == DispatchKey.XPU: + guard_field = "c10::OptionalDeviceGuard guard_;" + elif self.backend_index.dispatch_key == DispatchKey.MTIA: + guard_field = "c10::OptionalDeviceGuard guard_;" + else: + guard_field = "" + + indent = " " * 4 + class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns)) + lines = ( + f"struct {class_name} final : public {parent_class} {{", + f"{textwrap.indent(class_ctor_str, indent)}", + f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}", + " const Tensor& maybe_get_output(int64_t output_idx) override {", + f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit + " }", + # type: ignore[possibly-undefined] # TODO: audit + f" std::array<{output_type}, {len(f.func.returns)}> outputs_;", + f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit + f"{textwrap.indent(guard_field, indent)}", + "};", + ) + return "\n".join(line for line in lines if line) + + @method_with_native_function + def gen_one(self, f: NativeFunction) -> str | None: + assert not f.manual_kernel_registration + + if ( + self.target is Target.REGISTRATION + and not self.selector.is_native_function_selected(f) + ): + return None + + # TODO: Now, there is something interesting going on here. In the code below, + # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace + # based on the out implementation. But in fact, out is definable by + # functional too (just not very efficiently), and this is honestly the + # MORE likely situation for a backend implementer. How do we pick? + # Well, taking a page from Haskell type classes and default methods, + # we could conceivably register a circular definition (out in terms + # of functional, and functional in terms of out) and just require + # someone to implement one or the other. We'd have to do a little bit + # of work to not register one of these "weak" definitions unless there + # is a strong definition somewhere in the DAG! So it's not implemented yet. + if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + and f.func.kind() is SchemaKind.out + ): + # Never generate a default implementation for out, that's what you + # have to define as a backend implementer + return None + + # Note [Direct dispatch bindings] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Signature of the non-dispatched function we'll expose in a header + # (e.g., at::cpu::add). We don't generate methods (TODO: do this + # when CPUTensor class is a thing); nor do we generate fallback + # bindings for manual_cpp_binding functions. + cpp_sig_group = CppSignatureGroup.from_native_function( + f, method=False, fallback_binding=False + ) + + # Signature of the wrapper function we'll register to the dispatcher + kern = self.backend_index.get_kernel(f) + sig = NativeSignature( + f.func, + prefix=f"wrapper_{self.backend_index.dispatch_key}_", + symint=kern is not None and kern.supports_symint(), + ) + + if self.target is Target.NAMESPACED_DECLARATION: + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += f"TORCH_API {cpp_sig.decl()};\n" + return result + + elif self.target is Target.NAMESPACED_DEFINITION: + + def generate_defn(cpp_sig: CppSignature) -> str: + return f""" +{cpp_sig.defn()} {{ +return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); +}} +""" + + result = "" + for cpp_sig in cpp_sig_group.signatures(symint=self.symint): + result += generate_defn(cpp_sig) + return result + + elif self.target is Target.ANONYMOUS_DEFINITION: + k = f.func.kind() + + # Construct the body of the wrapper function with signature sig + sig_body = [] + # We'll use context to keep track of any variables we've brought + # into scope while generating code + context: list[Binding | Expr] = list(sig.arguments()) + + # Initialize the class corresponding to this structured + # operator; feeding it the output argument(s) if it is known + if self.backend_index.dispatch_key is DispatchKey.Meta: + class_name = f"structured_{meta.name(self.g)}_meta_{k.name}" + parent_class = f"at::meta::structured_{meta.name(self.g)}" + elif ( + self.backend_index.dispatch_key + is DispatchKey.CompositeExplicitAutogradNonFunctional + ): + # TODO: dedup this branch + class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}" + parent_class = f"at::meta::structured_{meta.name(self.g)}" + else: + metadata = self.backend_index.get_kernel(self.g) + assert metadata is not None + class_name = f"structured_{metadata.kernel}_{k.name}" + parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}" + + if self.backend_index.device_guard: + device_check_args = itertools.chain( + f.func.arguments.out, f.func.arguments.flat_positional + ) + sig_body.append( + RegisterDispatchKey.gen_device_check( + f.device_check, list(device_check_args), sig.name() + ) + ) + + if k is SchemaKind.functional: + sig_body.append(f"{class_name} op;") + elif k is SchemaKind.inplace: + sig_body.append(f"{class_name} op(self);") + elif k is SchemaKind.out: + out_args_str = ", ".join(a.name for a in f.func.arguments.out) + sig_body.append(f"{class_name} op({out_args_str});") + + # Translate the input native arguments into structured + # arguments for the meta call + meta_exprs = ", ".join( + e.expr + for e in translate( + context, structured.meta_arguments(self.g), method=False + ) + ) + + if self.g.out.precomputed: + # If this function group has precomputed elements, the meta function + # returns a struct containing them which must be saved so that it + # can be unpacked when generating code to call the impl. + sig_body.append(f"auto precompute = op.meta({meta_exprs});") + + # Put all of the contents of the precompute struct into the context + # so that translate will be able to return the correct args for the + # call to the impl. + precomputed_values = [ + *self.g.out.precomputed.replace.values(), + self.g.out.precomputed.add, + ] + for precomputed_elems in precomputed_values: + context.extend( + Expr( + expr=f"precompute.{arg.name}", + type=structured.argument_type(arg, binds=arg.name), + ) + for arg in precomputed_elems + ) + + # Add a use of the precompute struct so FB internal compilers don't + # complain that there is an unused variable. + sig_body.append("(void)precompute;") + else: + sig_body.append(f"op.meta({meta_exprs});") + + # After running meta, op.outputs_ is guaranteed to be valid; + # add it to the context + out_args = structured.out_arguments(self.g) + for i, out_arg in enumerate(out_args): + assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type + + if k is SchemaKind.out: + expr = f"op.maybe_get_output({i})" + else: + expr = f"op.outputs_[{i}]" + + context.append( + Expr( + expr=expr, + # TODO: Stop hardcoding that the output type is a Tensor. Note + # that for the codegen here this is fine because outputs_ is + # hardcoded to be tensor already + type=NamedCType( + out_arg.nctype.name, MutRefCType(BaseCType(tensorT)) + ), + ) + ) + + # With the expanded context, do the impl call (if not a meta + # function) + if ( + self.backend_index.dispatch_key + == DispatchKey.CompositeExplicitAutogradNonFunctional + ): + # TODO: https://github.com/pytorch/pytorch/issues/53023 + out_sig_group = CppSignatureGroup.from_native_function( + self.g.out, method=False, fallback_binding=f.manual_cpp_binding + ) + out_sig = out_sig_group.most_faithful_signature() + api_name = out_sig.name() + out_exprs = ", ".join( + e.expr + for e in translate(context, out_sig.arguments(), method=False) + ) + # TODO: I think this means structured won't work with method + # only functions (but maybe you're saved by faithful? iunno.) + # NB: Originally I wrote this as an at::redispatch call, but + # I got in trouble because that meant I needed a DispatchKeySet + # in the wrapper function, which meant I needed a DispatchKeySet + # in the DispatchKeyFunctions declarations, but the defined API + # there does NOT permit a dispatch key set. I think you can + # probably unwind this by calling some function to do the TLS + # fetch and get the DispatchKeySet when you don't have it, but + # I didn't do it for this version + sig_body.append(f"at::{api_name}({out_exprs});") + elif self.backend_index.dispatch_key != DispatchKey.Meta: + impl_exprs = ", ".join( + e.expr + for e in translate( + context, structured.impl_arguments(self.g), method=False + ) + ) + sig_body.append(f"op.impl({impl_exprs});") + + # Go over each output, and check if there is a proxy created for it. + # If so, copy it over to the original output. + if k is SchemaKind.out or k is SchemaKind.inplace: + for i in range(len(f.func.returns)): + sig_body.append( + f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);" + ) + + # Destructively return the final tensors + # TODO: Do this in translate instead + if k is SchemaKind.functional: + if len(f.func.returns) == 1: + ret_expr = "std::move(op.outputs_[0])" # small optimization + else: + moved = ", ".join( + f"std::move(op.outputs_[{i}])" + for i in range(len(f.func.returns)) + ) + ret_expr = f"std::make_tuple({moved})" + elif k is SchemaKind.inplace: + ret_expr = "self" + elif k is SchemaKind.out: + if len(f.func.returns) == 1: + ret_expr = f.func.arguments.out[0].name + else: + refs = ", ".join(a.name for a in f.func.arguments.out) + ret_expr = f"std::forward_as_tuple({refs})" + sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit + + sig_body_str = "\n".join(sig_body) + + # For an overview of what this template code looks like, see + # https://github.com/pytorch/rfcs/pull/9 + return f"""\ +{ + self.gen_class( + f, + k, + class_name=class_name, + parent_class=parent_class, + generate_super=self.g.out.structured_inherits is not None, + ) + } + +{sig.defn()} {{ +{sig_body_str} +}} +""" + + elif self.target is Target.REGISTRATION: + return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));' + else: + assert_never(self.target) + # Silence mypy's "Missing return statement" error + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/ufunc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..045d8de110e7442d0732aee483f0aab7015140d7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/ufunc.py @@ -0,0 +1,553 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torchgen.api.ufunc as ufunc +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + CType, + Expr, + NamedCType, + opmath_t, + scalar_t, + StructuredImplSignature, + VectorizedCType, +) +from torchgen.context import with_native_function +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, + NativeFunctionsGroup, + ScalarType, + UfuncKey, +) +from torchgen.utils import OrderedSet + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torchgen.api.ufunc import UfunctorBindings + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CUDA STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# NB: not bothering to generate dispatch stub forward declaration in header, +# we can just paste it wherever necessary + +# TODO: use BackendIndex +# dispatch_key: DispatchKey # only CPU/CUDA right now + + +# Represents functors for implementing CUDA ufuncs. +# Functors are templated by scalar_t because when USERS instantiate functors +# they are templated. A functor looks something like this: +# +# template +# struct CUDAFunctorOnSelf_add { +# using opmath_t = at::opmath_type; +# opmath_t other_; +# opmath_t alpha_; +# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) +# : other_(other), alpha_(alpha) {} +# __device__ scalar_t operator()(scalar_t self) { +# return ufunc::add(static_cast(self), other_, alpha_); +# } +# }; +# +@dataclass(frozen=True) +class UfunctorSignature: + g: NativeFunctionsGroup + scalar_tensor_idx: int | None + name: str + + def arguments(self) -> UfunctorBindings: + return ufunc.ufunctor_arguments( + self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t + ) + + def fields(self) -> list[Binding]: + # fields are renamed to have a trailing underscore, as is conventional + return [b.rename(f"{b.name}_") for b in self.arguments().ctor] + + def returns_type(self) -> CType: + # TODO: don't hardcode; return type will be inferred based on tags on + # the native function + return BaseCType(scalar_t) + + def decl_fields(self) -> str: + return "\n".join(f"{f.type} {f.name};" for f in self.fields()) + + def inline_defn_ctor(self) -> str: + args_str = ", ".join(a.decl() for a in self.arguments().ctor) + # NB: hypothetically could do this with translate but the + # transition here is very regular + init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor) + return f"{self.name}({args_str}) : {init_str} {{}}" + + def decl_apply(self) -> str: + args_str = ", ".join(a.decl() for a in self.arguments().apply) + return f"{self.returns_type().cpp_type()} operator()({args_str}) const" + + +@dataclass(frozen=True) +class UfuncSignature: + g: NativeFunctionsGroup + name: str + compute_t: CType + + def arguments(self) -> list[Binding]: + return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) + + def call(self, ctx: Sequence[Binding | Expr]) -> str: + return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +# steps: +# 1. take the functional signature +# 2. use api.ufunc to convert it to template signature. this establishes +# the type of the template function +# 3. use api.ufunc (II) to generate a split struct / operator() signature. +# this establish context in which we call the template signature +# +# StructuredImplSignature context +# ~> functor constructor sig +# +# Functor constructor context +# ~> functor fields sig +# +# Functor apply context (functor fields + functor apply sig) +# ~> template sig +# + + +def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool: + num_tensors = sum( + 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like() + ) + return num_tensors == 2 + + +def compute_ufunc_cuda_functors( + g: NativeFunctionsGroup, +) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]: + # First, build the functors. + ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {} + ufunctors: list[str] = [] + loops = g.out.ufunc_inner_loop + scalar_tensor_idx_lookup = { + UfuncKey.CUDAFunctorOnSelf: 1, + UfuncKey.CUDAFunctorOnOther: 0, + UfuncKey.CUDAFunctor: None, + } + if eligible_for_binary_scalar_specialization(g): + keys = [ + UfuncKey.CUDAFunctorOnSelf, + UfuncKey.CUDAFunctorOnOther, + UfuncKey.CUDAFunctor, + ] + else: + keys = [UfuncKey.CUDAFunctor] + for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]: + assert k not in loops, f"cannot use {k} on non-binary function" + for k in keys: + # If the key was directly defined, skip functor codegen; we assume the + # user already done it for us + if k in loops: + ufunctor_sig = UfunctorSignature( + g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name + ) + for dtype in loops[k].supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + continue + + # Note [ScalarOnly and Generic must match names for CUDA] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Otherwise, look in ANY of the generic entries. For simplicity of + # codegen, both ScalarOnly and Generic are defined, the ufunc name + # must match (if they didn't match, we'd have to generate distinct + # functors per dtype, which is awful, so we're not going to do it unless + # someone really forces us to) + ufunc_name = None + supported_dtypes: OrderedSet[ScalarType] = OrderedSet() + for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]: + if lk not in loops: + continue + if ufunc_name is None: + ufunc_name = loops[lk].name + else: + # See Note [ScalarOnly and Generic must match names for CUDA] + assert ufunc_name == loops[lk].name, ( + "ScalarOnly and Generic must have same ufunc name" + ) + supported_dtypes |= loops[lk].supported_dtypes + assert ufunc_name is not None + + name = f"{k}_{ufunc_name}" + ufunctor_sig = UfunctorSignature( + g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name + ) + for dtype in supported_dtypes: + ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig + + ufunc_sig = UfuncSignature( + g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t) + ) + apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply + ufunctors.append( + f""" +template +struct {ufunctor_sig.name} {{ + using opmath_t = at::opmath_type; + {ufunctor_sig.decl_fields()} + {ufunctor_sig.inline_defn_ctor()} + __device__ {ufunctor_sig.decl_apply()} {{ + return {ufunc_sig.call(apply_ctx)}; + }} +}}; +""" + ) + + return ufunctor_sigs, "\n".join(ufunctors) + + +@dataclass(frozen=True) +class BinaryScalarSpecializationConfig: + scalar_idx: int + ctor_tensor: str + ufunc_key: UfuncKey + + +BinaryScalarSpecializationConfigs = [ + BinaryScalarSpecializationConfig( + scalar_idx=0, + ctor_tensor="self", + ufunc_key=UfuncKey.CUDAFunctorOnOther, + ), + BinaryScalarSpecializationConfig( + scalar_idx=1, + ctor_tensor="other", + ufunc_key=UfuncKey.CUDAFunctorOnSelf, + ), +] + + +def compute_ufunc_cuda_dtype_body( + g: NativeFunctionsGroup, + dtype: ScalarType, + inner_loops: dict[UfuncKey, UfunctorSignature], + parent_ctx: Sequence[Binding], +) -> str: + body = "using opmath_t = at::opmath_type;" + body += "if (false) {}\n" # for ease of codegen + for config in BinaryScalarSpecializationConfigs: + if config.ufunc_key not in inner_loops: + continue + ufunctor_sig = inner_loops[config.ufunc_key] + scalar_idx = config.scalar_idx + 1 + # Make a copy and at the same time widen the type (not permissible + # without copy; we don't want to mutate the input argument anyway) + ctx: list[Expr | Binding] = list(parent_ctx) + ctx.append( + Expr( + expr=f"iter.scalar_value({scalar_idx})", + type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)), + ) + ) + ufunctor_ctor_exprs_str = ", ".join( + a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor) + ) + + # NB: ufunctor must be allocated before iter.remove_operand is called, + # as it relies on iter + body += f"""\ +else if (iter.is_cpu_scalar({scalar_idx})) {{ + {ufunctor_sig.name} ufunctor({ufunctor_ctor_exprs_str}); + iter.remove_operand({scalar_idx}); + gpu_kernel(iter, ufunctor); +}}""" + + ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor] + ufunctor_ctor_exprs_str = ", ".join( + a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor) + ) + body += f""" +else {{ + gpu_kernel(iter, {ufunctor_sig.name}({ufunctor_ctor_exprs_str})); +}} + """ + return body + + +@with_native_function +def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str: + # First, build the functors, indexing them by dtype + ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g) + + # Next, build the conditionals + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA)) + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunctor_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())} + }} +) +""" + ) + + dtype_cases_str = "\n".join(dtype_cases) + + stub_sig = StubSignature(g) + + return f""" +{ufunctors} + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()} + +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}", + {dtype_cases_str} + ); +}} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}) + +{sig.defn()} {{ + {stub_sig.direct_call(sig.arguments())}; +}} +""" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# CPU STUFF +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@dataclass(frozen=True) +class StubSignature: + g: NativeFunctionsGroup + + @property + def name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_stub" + + @property + def kernel_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_kernel" + + @property + def type_name(self) -> str: + return f"{str(self.g.functional.func.name.name)}_fn" + + def arguments(self) -> list[Binding]: + return ufunc.stub_arguments(self.g) + + def type(self) -> str: + cpp_args = self.arguments() + return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})" + + def dispatch_decl(self) -> str: + return f"DECLARE_DISPATCH({self.type_name}, {self.name})" + + def dispatch_defn(self) -> str: + return f"DEFINE_DISPATCH({self.name})" + + def kernel_defn(self) -> str: + return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})" + + def type_defn(self) -> str: + return f"using {self.type_name} = {self.type()}" + + # must be called from context where this is TensorIteratorBase* + def call(self, ctx: Sequence[Binding]) -> str: + return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + # used in CUDA to skip the unnecessary dynamic dispatch + def direct_call(self, ctx: Sequence[Binding]) -> str: + return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" + + +@with_native_function +def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU)) + + return f""" +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()} +{stub_sig.dispatch_defn()}; + +{sig.defn()} {{ + {stub_sig.call(sig.arguments())}; +}} +""" + + +def compute_ufunc_cpu_dtype_body( + g: NativeFunctionsGroup, + dtype: ScalarType, + inner_loops: dict[UfuncKey, UfuncSignature], + parent_ctx: Sequence[Binding], +) -> str: + assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" + assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} + scalar_loop = inner_loops[UfuncKey.CPUScalar] + vec_loop = None + if UfuncKey.CPUVector in inner_loops: + vec_loop = inner_loops[UfuncKey.CPUVector] + + # NB: We DON'T use translate here, because translate is + # incapable of CSE'ing the scalar accesses in case it is also + # used by Vectorized; also, the unpacking here is very simple + # and only affects Scalar; everything else is implicitly captured + # by the lambda + + # Setup scalar in scope + body = [] + ctx = [] + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType( + BaseTy.Scalar + ): + continue + body.append(f"auto _s_{b.name} = {b.name}.to();") + ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) + if vec_loop is not None: + for b in parent_ctx: + if isinstance(b.argument, Argument) and b.argument.type != BaseType( + BaseTy.Scalar + ): + continue + body.append( + f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});" + ) + ctx.append( + Expr( + f"_v_{b.name}", + NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))), + ) + ) + + # Setup lambda signature + # NB: simplified version of ufunctor_arguments + scalar_bindings = [] + vec_bindings = [] + for a in g.functional.func.arguments.flat_non_out: + if not a.type.is_tensor_like(): + continue + assert a.type == BaseType(BaseTy.Tensor) + scalar_bindings.append( + Binding( + name=a.name, + nctype=NamedCType(a.name, BaseCType(scalar_t)), + argument=a, + ) + ) + if vec_loop is not None: + vec_bindings.append( + Binding( + name=a.name, + nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), + argument=a, + ) + ) + + def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]: + r: list[Expr | Binding] = [] + r.extend(ctx) + r.extend(b) + return r + + body_str = "\n".join(body) + if vec_loop is not None: + return f""" +{body_str} +cpu_kernel_vec(iter, + [=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}, + [=]({", ".join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }} +); +""" + else: + return f""" +{body_str} +cpu_kernel(iter, + [=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }} +); +""" + + +@with_native_function +def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str: + stub_sig = StubSignature(g) + + # Reindex the ufunc by dtypes; processing generic/scalaronly as well + loops = g.out.ufunc_inner_loop + ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {} + for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: + lks = [] + # ORDER MATTERS: this specifies overriding precedence + if k in loops: # should happen rarely + lks.append(k) + if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar: + lks.append(UfuncKey.ScalarOnly) + if UfuncKey.Generic in loops: + lks.append(UfuncKey.Generic) + # TODO: don't hardcode ufunc:: namespace here, should be centralized smh + for lk in lks: + for dtype in loops[lk].supported_dtypes: + compute_t: CType + if k is UfuncKey.CPUScalar: + compute_t = BaseCType(scalar_t) + elif k is UfuncKey.CPUVector: + compute_t = VectorizedCType(BaseCType(scalar_t)) + else: + raise AssertionError + inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {}) + if k not in inner_ufunc_sigs: + inner_ufunc_sigs[k] = UfuncSignature( + g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t + ) + + # Build the conditionals + dtype_cases = [] + for dtype, inner_ufunc_sigs in ufunc_sigs.items(): + dtype_cases.append( + f""" +AT_DISPATCH_CASE(at::ScalarType::{dtype}, + [&]() {{ + {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())} + }} +) +""" + ) + + dtype_cases_str = "\n".join(dtype_cases) + return f""" +namespace {{ + +{stub_sig.kernel_defn()} {{ + AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}", + {dtype_cases_str} + ); +}} + +}} // anonymous namespace + +{stub_sig.type_defn()}; +{stub_sig.dispatch_decl()} +REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name}) +""" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py new file mode 100644 index 0000000000000000000000000000000000000000..15b74ac9c21a70d3f97df0dae210087072c15142 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import os +from enum import Enum +from operator import itemgetter +from pathlib import Path +from typing import Any + +import torch +from torch.jit.generate_bytecode import generate_upgraders_bytecode +from torchgen.code_template import CodeTemplate +from torchgen.operator_versions.gen_mobile_upgraders_constant import ( + MOBILE_UPGRADERS_HEADER_DESCRIPTION, +) + + +class ByteCode(Enum): + instructions = 1 + constants = 2 + types = 3 + operators = 4 + register_size = 5 + + +EXCLUDED_OP_SET = [ + "aten::full.names", + "aten::full.out", + "aten::full", +] + +EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"] + +ONE_INSTRUCTION = CodeTemplate( + """ + Instruction{OpCode::${operator_name}, ${X}, ${N}},""" +) + +INSTRUCTION_LIST = CodeTemplate( + """std::vector({ + ${instruction_list} + }), // instructions list""" +) + +ONE_CONSTANT = CodeTemplate( + """ + c10::IValue(${constant}),""" +) + +CONSTANT_LIST = CodeTemplate( + """std::vector({ + ${constant_list} + }), // constants list""" +) + +CONSTANTS_LIST_EMPTY = """std::vector(), // constants list""" + +ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""") + +TYPE_LIST = CodeTemplate( + """std::vector({ + ${type_list} + }), // types list""" +) + +TYPE_LIST_EMPTY = """std::vector(), // types list""" + +ONE_OPERATOTR_STRING = CodeTemplate( + """ + OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),""" +) + +OPERATOR_STRING_LIST = CodeTemplate( + """ + std::vector({ + ${operator_string_list} + }), // operators list""" +) + +ONE_UPGRADER_FUNCTION = CodeTemplate( + """ + mobile::Function::registerFunc( + "${upgrader_name}", + ${instruction_list}, + ${constant_list}, + ${type_list}, + ${register_size} + )""" +) + +ONE_UPGRADER_SRC = CodeTemplate( + """ + ByteCodeFunctionWithOperator({ + ${bytecode_function}, + ${operator_string_list} + }),""" +) + + +ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate( + """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})""" +) # noqa: E501 + +ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate( + """ + {std::string("${operator_name}"), + std::vector({ + ${upgrader_list_in_version_map} + })},""" +) + + +OPERATOR_VERSION_MAP = CodeTemplate( + """ +const std::unordered_map> +getOperatorVersionMapForMobile() { + static std::unordered_map> + operatorVersionMapForMobile({ + ${operator_list_in_version_map} + }); + return operatorVersionMapForMobile; +} +""" +) + + +UPGRADER_CPP_SRC = CodeTemplate( + MOBILE_UPGRADERS_HEADER_DESCRIPTION + + """ +#include +#include +#include + +namespace torch { +namespace jit { + +// clang-format off + +// From operator_versions_map +${operator_version_map} + +const std::vector& getUpgraderBytecodeList() { + auto generate_upgrader_bytecode_list = []() { + std::vector upgrader_function_list({ + ${upgrader_bytecode} + }); + for (const auto& upgrader_function : upgrader_function_list) { + for (const auto& op : upgrader_function.operators) { + upgrader_function.function.append_operator( + op.name, + op.overload_name, + op.num_specified_args); + } + } + return upgrader_function_list; + }; + static std::vector upgraderBytecodeList = + generate_upgrader_bytecode_list(); + return upgraderBytecodeList; +} + +// clang-format on + +} // namespace jit +} // namespace torch +""" +) + +UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp" + +UPGRADER_ELEMENT = CodeTemplate( + """\ +Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}), +""" +) + +PER_OPERATOR_UPGRADER_LIST = CodeTemplate( + """\ +{ + std::string(${operator_name}), + std::vector({${upgrader_list}}); +} +""" +) + + +def construct_instruction(instruction_list_from_yaml: list[Any]) -> str: + instruction_list_part = [ + ONE_INSTRUCTION.substitute( + operator_name=instruction[0], + X=instruction[1], + N=instruction[2], + ) + for instruction in instruction_list_from_yaml + ] + return INSTRUCTION_LIST.substitute( + instruction_list="".join(instruction_list_part).lstrip("\n") + ) + + +def construct_constants(constants_list_from_yaml: list[Any]) -> str: + constants_list_part = [] + for constant_from_yaml in constants_list_from_yaml: + convert_constant = None + if isinstance(constant_from_yaml, str): + # Add quotes if it's string + convert_constant = f'"{constant_from_yaml}"' + elif isinstance(constant_from_yaml, bool): + convert_constant = "true" if constant_from_yaml else "false" + elif constant_from_yaml is None: + convert_constant = "" + elif isinstance(constant_from_yaml, int): + convert_constant = str(constant_from_yaml) + else: + raise ValueError( + f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. " + "Please add change in construct_constants function in gen_mobile_upgraders.py." + ) + constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant)) + if len(constants_list_part) == 0: + return CONSTANTS_LIST_EMPTY + return CONSTANT_LIST.substitute( + constant_list="".join(constants_list_part).lstrip("\n") + ) + + +def construct_operators(operator_list_from_yaml: list[Any]) -> str: + operator_list_part = [ + ONE_OPERATOTR_STRING.substitute( + operator_name=operator[0], + overload_name=operator[1], + num_of_args=operator[2], + ) + for operator in operator_list_from_yaml + ] + return OPERATOR_STRING_LIST.substitute( + operator_string_list="".join(operator_list_part).lstrip("\n") + ) + + +def construct_types(types_tr_list_from_yaml: list[Any]) -> str: + types_tr_list_part = [ + ONE_TYPE.substitute(type_str=types_tr) for types_tr in types_tr_list_from_yaml + ] + if len(types_tr_list_part) == 0: + return TYPE_LIST_EMPTY + return TYPE_LIST.substitute(type_list="".join(types_tr_list_part).lstrip("\n")) + + +def construct_register_size(register_size_from_yaml: int) -> str: + if not isinstance(register_size_from_yaml, int): + raise ValueError( + f"Input register size is {register_size_from_yaml} and" + "it's type is {type(register_size_from_yaml)}. An int type is expected." + ) + return str(register_size_from_yaml) + + +def construct_version_maps( + upgrader_bytecode_function_to_index_map: dict[str, Any], +) -> str: + version_map = torch._C._get_operator_version_map() + sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] + sorted_version_map = dict(sorted_version_map_) + + operator_list_in_version_map_part = [] + for op_name in sorted_version_map: + upgraders_in_version_map_part = [] + # TODO: remove the skip after these two operators schemas are fixed + if op_name in EXCLUDED_OP_SET: + continue + upgrader_ranges = torch._C._get_upgrader_ranges(op_name) + upgrader_entries = sorted_version_map[op_name] + assert len(upgrader_ranges) == len(upgrader_entries) + for idx, upgrader_entry in enumerate(upgrader_entries): + upgrader_name = upgrader_entry.upgrader_name + bytecode_function_index = upgrader_bytecode_function_to_index_map[ + upgrader_name + ] + upgraders_in_version_map_part.append( + ONE_UPGRADER_IN_VERSION_MAP.substitute( + upgrader_min_version=upgrader_ranges[idx].min_version, + upgrader_max_version=upgrader_ranges[idx].max_version, + upgrader_name=upgrader_name, + bytecode_func_index=bytecode_function_index, + ) + ) + operator_list_in_version_map_part.append( + ONE_OPERATOR_IN_VERSION_MAP.substitute( + operator_name=op_name, + upgrader_list_in_version_map="".join(upgraders_in_version_map_part), + ) + ) + return OPERATOR_VERSION_MAP.substitute( + operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip( + "\n" + ) + ) + + +def get_upgrader_bytecode_function_to_index_map( + upgrader_dict: list[dict[str, Any]], +) -> dict[str, Any]: + upgrader_bytecode_function_to_index_map = {} + index = 0 + for upgrader_bytecode in upgrader_dict: + for upgrader_name in upgrader_bytecode: + if upgrader_name in EXCLUE_UPGRADER_SET: + continue + upgrader_bytecode_function_to_index_map[upgrader_name] = index + index += 1 + return upgrader_bytecode_function_to_index_map + + +def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None: + upgrader_bytecode_function_to_index_map = ( + get_upgrader_bytecode_function_to_index_map(upgrader_dict) + ) + version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map) + all_upgrader_src_string = [] + for upgrader_bytecode in upgrader_dict: + for upgrader_name, bytecode in upgrader_bytecode.items(): + # TODO: remove the skip after these two operators schemas are fixed + if upgrader_name in EXCLUE_UPGRADER_SET: + continue + instruction_list_str = "" + constant_list_str = "" + type_list_str = "" + register_size_str = "" + operator_list_str = "" + for table_name, contents in bytecode.items(): + element = ByteCode[table_name] + if element is ByteCode.instructions: + instruction_list_str = construct_instruction(contents) + elif element is ByteCode.constants: + constant_list_str = construct_constants(contents) + elif element is ByteCode.operators: + operator_list_str = construct_operators(contents) + elif element is ByteCode.types: + type_list_str = construct_types(contents) + elif element is ByteCode.register_size: + register_size_str = construct_register_size(contents) + + one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute( + upgrader_name=upgrader_name, + instruction_list=instruction_list_str, + constant_list=constant_list_str, + type_list=type_list_str, + register_size=register_size_str, + ) + one_upgrader_src_string = ONE_UPGRADER_SRC.substitute( + bytecode_function=one_upgrader_function_string.lstrip("\n"), + operator_string_list=operator_list_str.lstrip("\n"), + ) + all_upgrader_src_string.append(one_upgrader_src_string) + + upgrader_file_content = UPGRADER_CPP_SRC.substitute( + operator_version_map=version_map_src, + upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"), + ) + print("writing file to : ", cpp_path + "/" + UPGRADER_MOBILE_FILE_NAME) + with open(os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME), "wb") as out_file: + out_file.write(upgrader_file_content.encode("utf-8")) + + +def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]: + sorted_upgrader_list = sorted( + upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader)) + ) + return sorted_upgrader_list + + +def main() -> None: + upgrader_list = generate_upgraders_bytecode() + sorted_upgrader_list = sort_upgrader(upgrader_list) + for up in sorted_upgrader_list: + print("after sort upgrader : ", next(iter(up))) + + pytorch_dir = Path(__file__).resolve().parents[2] + upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile" + write_cpp(str(upgrader_path), sorted_upgrader_list) + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..04b5ad887e54153115eeca7b6686d7c2de8dfc06 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py @@ -0,0 +1,7 @@ +MOBILE_UPGRADERS_HEADER_DESCRIPTION = """/** + * @generated + * This is an auto-generated file. Please do not modify it by hand. + * To re-generate, please run: + * cd ~/pytorch && python torchgen/operator_versions/gen_mobile_upgraders.py + */ +""" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/operator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..8047f033e3d2b0209e03924b355e94a06eceace6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/operator.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +# This class holds information about a single operator used to determine +# the outcome of a selective/custom PyTorch build that doesn't include +# registration code for all the supported operators. This is done to +# reduce the size of the generated binary so that it can be deployed in +# situations where binary size comes at a premium. +# +@dataclass(frozen=True) +class SelectiveBuildOperator: + # The name of the operator. This includes the aten::, etc... prefix + # The operator name may or may not have the overload name. If this + # operator name does not specify an overload name, the way to determine + # if this entry refers to the family of operators with this base name + # or just the operator with this name is to look at the value of the + # 'include_all_overloads' flag in this class. + name: str + + # True if this is a root operator (i.e. called directly from a + # TorchScript model, etc...). An operator is considered to be a + # root operator if it is called directly from any one of the models + # that this instance of the pytorch library was built for. Hence, it + # may not be a root operator in all of the models that are used in + # this instance of the pytorch library. + is_root_operator: bool + + # Is this operator used for on-device training? If True, then we need to + # use the information to generate code in VariableType_N.cpp for registration + # of training related operators. Again, this is True if this operator + # is used for training in one or more models used by this instance of the + # pytorch library. + is_used_for_training: bool + + # If True, it indicates that this operator instance (object) refers to an + # operator without the overload name and should apply to all overloads + # which have this operator name as the base name. This flag is applicable + # only for objects that have operator names without a DOT (period) character + # in them. + # + # Note: This flag is a temporary workaround to grandfather in the current + # static selective (custom) build mechanism, which largely ignores overload + # names when determining whether to select operators for registration + # purposes. + include_all_overloads: bool + + # Debug Information at the operator level + _debug_info: tuple[str, ...] | None + + @staticmethod + def from_yaml_dict( + op_name: str, op_info: dict[str, object] + ) -> SelectiveBuildOperator: + allowed_keys = { + "name", + "is_root_operator", + "is_used_for_training", + "include_all_overloads", + "debug_info", + } + + if len(set(op_info.keys()) - allowed_keys) > 0: + raise Exception( # noqa: TRY002 + "Got unexpected top level keys: {}".format( + ",".join(set(op_info.keys()) - allowed_keys), + ) + ) + + if "name" in op_info: + assert op_name == op_info["name"] + + is_root_operator = op_info.get("is_root_operator", True) + assert isinstance(is_root_operator, bool) + + is_used_for_training = op_info.get("is_used_for_training", True) + assert isinstance(is_used_for_training, bool) + + include_all_overloads = op_info.get("include_all_overloads", True) + assert isinstance(include_all_overloads, bool) + + debug_info: tuple[str, ...] | None = None + if "debug_info" in op_info: + di_list = op_info["debug_info"] + assert isinstance(di_list, list) + debug_info = tuple(str(x) for x in di_list) + + return SelectiveBuildOperator( + name=op_name, + is_root_operator=is_root_operator, + is_used_for_training=is_used_for_training, + include_all_overloads=include_all_overloads, + _debug_info=debug_info, + ) + + @staticmethod + def from_legacy_operator_name_without_overload( + name: str, + ) -> SelectiveBuildOperator: + return SelectiveBuildOperator( + name=name, + is_root_operator=True, + is_used_for_training=True, + include_all_overloads=True, + _debug_info=None, + ) + + def to_dict(self) -> dict[str, object]: + ret: dict[str, object] = { + "is_root_operator": self.is_root_operator, + "is_used_for_training": self.is_used_for_training, + "include_all_overloads": self.include_all_overloads, + } + if self._debug_info is not None: + ret["debug_info"] = self._debug_info + + return ret + + +def merge_debug_info( + lhs: tuple[str, ...] | None, + rhs: tuple[str, ...] | None, +) -> tuple[str, ...] | None: + # Ensure that when merging, each entry shows up just once. + if lhs is None and rhs is None: + return None + + return tuple(set((lhs or ()) + (rhs or ()))) + + +def combine_operators( + lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator +) -> SelectiveBuildOperator: + if str(lhs.name) != str(rhs.name): + raise Exception( # noqa: TRY002 + f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead" + ) + + return SelectiveBuildOperator( + name=lhs.name, + # Consider this operator to be a root operator if it is a + # root operator in any of the models used in this instance of + # the pytorch library. + is_root_operator=lhs.is_root_operator or rhs.is_root_operator, + # Consider this operator to be a training operator if it is + # an operator used for training in any of the models used + # in this instance of the pytorch library. + is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training, + include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads, + _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info), + ) + + +def merge_operator_dicts( + lhs: dict[str, SelectiveBuildOperator], + rhs: dict[str, SelectiveBuildOperator], +) -> dict[str, SelectiveBuildOperator]: + operators: dict[str, SelectiveBuildOperator] = {} + for op_name, op in list(lhs.items()) + list(rhs.items()): + new_op = op + if op_name in operators: + new_op = combine_operators(operators[op_name], op) + + operators[op_name] = new_op + + return operators + + +def strip_operator_overload_name(op_name: str) -> str: + return op_name.split(".", maxsplit=1)[0] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/selector.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/selector.py new file mode 100644 index 0000000000000000000000000000000000000000..04acc354203ade2f48dcef56fd9d9ef70c82ad1d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/selector.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import yaml + +from torchgen.selective_build.operator import ( + merge_debug_info, + merge_operator_dicts, + SelectiveBuildOperator, + strip_operator_overload_name, +) + + +if TYPE_CHECKING: + from torchgen.model import NativeFunction + + +# A SelectiveBuilder holds information extracted from the selective build +# YAML specification. +# +# It includes information about the build's selectivity, the debug_info +# associated with this selective build (opaque string), and the set of +# operators that should be included in the build. +# +@dataclass(frozen=True) +class SelectiveBuilder: + # If true, then the build is not selective, and includes all + # operators. + include_all_operators: bool + + # Debug Information at the selective/custom build level. + _debug_info: tuple[str, ...] | None + + # A dictionary of operator -> operator metadata. + operators: dict[str, SelectiveBuildOperator] + + # A dictionary of selected kernel tags and dtypes. Typically a + # PyTorch Operator Kernel (function) may have many code paths + # that are specialized for many many Tensor dtypes, so it's not + # one per kernel function, but there could be many per kernel + # function. The tag isn't a kernel function name, but some fragment + # of the kernel function implementation itself. + kernel_metadata: dict[str, list[str]] + + # ExecuTorch only. A dictionary of kernel tag -> list of (list of input + # dtypes for tensor-like input args). + # This is from selective.yaml + et_kernel_metadata: dict[str, list[str]] + + # A set of all the custom torch bind classes used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + custom_classes: set[str] + + # A set of all the build features used by the selected models + # Stored as a set internally to remove duplicates proactively, but written + # as a list to yamls + build_features: set[str] + + # If true, then fragments for all dtypes for all kernel functions + # are included as well as all custom classes. This is typically set when any one of the + # operator lists is generated from a mechanism other than + # tracing based selective build. + include_all_non_op_selectives: bool + + @staticmethod + def get_nop_selector() -> SelectiveBuilder: + return SelectiveBuilder.from_yaml_dict({"include_all_operators": True}) + + @staticmethod + def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder: + valid_top_level_keys = { + "include_all_non_op_selectives", + "include_all_operators", + "debug_info", + "operators", + "kernel_metadata", + "et_kernel_metadata", + "custom_classes", + "build_features", + } + top_level_keys = set(data.keys()) + if len(top_level_keys - valid_top_level_keys) > 0: + raise Exception( # noqa: TRY002 + "Got unexpected top level keys: {}".format( + ",".join(top_level_keys - valid_top_level_keys), + ) + ) + include_all_operators = data.get("include_all_operators", False) + assert isinstance(include_all_operators, bool) + + debug_info = None + if "debug_info" in data: + di_list = data["debug_info"] + assert isinstance(di_list, list) + + debug_info = tuple(str(x) for x in di_list) + + operators = {} + operators_dict = data.get("operators", {}) + assert isinstance(operators_dict, dict) + + for k, v in operators_dict.items(): + operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v) + + kernel_metadata = {} + kernel_metadata_dict = data.get("kernel_metadata", {}) + assert isinstance(kernel_metadata_dict, dict) + + for k, v in kernel_metadata_dict.items(): + kernel_metadata[str(k)] = [str(dtype) for dtype in v] + + et_kernel_metadata = data.get("et_kernel_metadata", {}) + assert isinstance(et_kernel_metadata, dict) + + custom_classes = data.get("custom_classes", []) + assert isinstance(custom_classes, Iterable) + custom_classes = set(custom_classes) + + build_features = data.get("build_features", []) + assert isinstance(build_features, Iterable) + build_features = set(build_features) + + include_all_non_op_selectives = data.get("include_all_non_op_selectives", False) + assert isinstance(include_all_non_op_selectives, bool) + + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, # type: ignore[arg-type] + build_features, # type: ignore[arg-type] + include_all_non_op_selectives, + ) + + @staticmethod + def from_yaml_str(config_contents: str) -> SelectiveBuilder: + contents = yaml.safe_load(config_contents) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_yaml_path(config_path: str) -> SelectiveBuilder: + with open(config_path) as f: + contents = yaml.safe_load(f) + return SelectiveBuilder.from_yaml_dict(contents) + + @staticmethod + def from_legacy_op_registration_allow_list( + allow_list: set[str], is_root_operator: bool, is_used_for_training: bool + ) -> SelectiveBuilder: + operators = {} + for op in allow_list: + operators[op] = { + "name": op, + "is_root_operator": is_root_operator, + "is_used_for_training": is_used_for_training, + "include_all_overloads": True, + } + return SelectiveBuilder.from_yaml_dict( + { + "operators": operators, + "include_all_non_op_selectives": True, + } + ) + + def is_operator_selected(self, name: str) -> bool: + if self.include_all_operators: + return True + + if name in self.operators: + return True + name = strip_operator_overload_name(name) + return name in self.operators and self.operators[name].include_all_overloads + + def is_native_function_selected(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected(op_name) + + def is_operator_selected_for_training(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + not_training_op = SelectiveBuildOperator( + name="", + is_root_operator=False, + is_used_for_training=False, + include_all_overloads=False, + _debug_info=None, + ) + op = not_training_op + if name in self.operators: + op = self.operators[name] + + name = strip_operator_overload_name(name) + base_op = not_training_op + if name in self.operators: + base_op = self.operators[name] + + return op.is_used_for_training or ( + base_op.include_all_overloads and base_op.is_used_for_training + ) + + def is_native_function_selected_for_training(self, func: NativeFunction) -> bool: + op_name = op_name_from_native_function(func) + return self.is_operator_selected_for_training(op_name) + + def is_root_operator(self, name: str) -> bool: + if not self.is_operator_selected(name): + return False + if self.include_all_operators: + return True + + if name in self.operators: + op: SelectiveBuildOperator = self.operators[name] + return op.is_root_operator + name = strip_operator_overload_name(name) + if name not in self.operators: + return False + base_op: SelectiveBuildOperator = self.operators[name] + return base_op.include_all_overloads and base_op.is_root_operator + + def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool: + if self.include_all_operators or self.include_all_non_op_selectives: + return True + + return ( + kernel_tag in self.kernel_metadata + and dtype in self.kernel_metadata[kernel_tag] + ) + + def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]: + """ + Return a list of kernel keys that cover the used ops + """ + # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used. + if op_name not in self.et_kernel_metadata: + return kernel_key if self.include_all_operators else [] + # Otherwise, only return the specific kernel keys. + + result_set = set() + + for model_kernel_keys in self.et_kernel_metadata[op_name]: + key_found = False + for key in kernel_key: + # Don't compare the version for now + if ( + key != "default" + and key.split("/")[1] == model_kernel_keys.split("/")[1] + ): + result_set.add(key) + key_found = True + break + if not key_found: + if "default" not in kernel_key: + raise Exception("Missing kernel for the model") # noqa: TRY002 + else: + result_set.add("default") + + return list(result_set) + + def to_dict(self) -> dict[str, object]: + ret: dict[str, object] = { + "include_all_non_op_selectives": self.include_all_non_op_selectives, + "include_all_operators": self.include_all_operators, + } + operators = {} + for op_name, op in self.operators.items(): + operators[op_name] = op.to_dict() + ret["operators"] = operators + + if self._debug_info is not None: + ret["debug_info"] = sorted(self._debug_info) + + ret["kernel_metadata"] = { + k: sorted(v) for (k, v) in self.kernel_metadata.items() + } + + ret["et_kernel_metadata"] = self.et_kernel_metadata + + ret["custom_classes"] = sorted(self.custom_classes) + + ret["build_features"] = sorted(self.build_features) + + return ret + + +def merge_kernel_metadata( + lhs: dict[str, list[str]], + rhs: dict[str, list[str]], +) -> dict[str, list[str]]: + kernel_metadata: dict[str, list[str]] = {} + for tag_name, dtypes in list(lhs.items()) + list(rhs.items()): + dtypes_copy = set(dtypes) + if tag_name in kernel_metadata: + dtypes_copy |= set(kernel_metadata[tag_name]) + + kernel_metadata[tag_name] = list(dtypes_copy) + + return kernel_metadata + + +def merge_et_kernel_metadata( + lhs: dict[str, list[str]], + rhs: dict[str, list[str]], +) -> dict[str, list[str]]: + merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set) + for op in list(lhs.keys()) + list(rhs.keys()): + merge_et_kernel_metadata[op].update(lhs.get(op, [])) + merge_et_kernel_metadata[op].update(rhs.get(op, [])) + + return {op: sorted(val) for op, val in merge_et_kernel_metadata.items()} + + +def combine_selective_builders( + lhs: SelectiveBuilder, rhs: SelectiveBuilder +) -> SelectiveBuilder: + include_all_operators = lhs.include_all_operators or rhs.include_all_operators + debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info) + operators = merge_operator_dicts(lhs.operators, rhs.operators) + kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata) + et_kernel_metadata = merge_et_kernel_metadata( + lhs.et_kernel_metadata, rhs.et_kernel_metadata + ) + include_all_non_op_selectives = ( + lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives + ) + custom_classes = lhs.custom_classes.union(rhs.custom_classes) + build_features = lhs.build_features.union(rhs.build_features) + return SelectiveBuilder( + include_all_operators, + debug_info, + operators, + kernel_metadata, + et_kernel_metadata, + custom_classes, + build_features, + include_all_non_op_selectives, + ) + + +def op_name_from_native_function(f: NativeFunction) -> str: + # This was originally read from the 'operator_name_with_overload' field in the + # declaration dict, which was the part before the first '(' in 'schema_string'. + return f"{f.namespace}::{f.func.name}" diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/.gitignore b/Prism/LLaDA2mini/LLaDA2mini_Baseline/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..06fcf0c6ecee82cc5fe808575c4af69b9527fdb6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/.gitignore @@ -0,0 +1,210 @@ +*.jsonl +*.json + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/LICENSE b/Prism/LLaDA2mini/LLaDA2mini_Baseline/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0f363b42d00f2a291c617c43e6fc3a9f142729be --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 preordinary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c50ad3edd2cb1dea76048d416624a7c7db7c3209 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/__init__.py @@ -0,0 +1,7 @@ +import logging +import os + +from .evaluator import evaluate, simple_evaluate + + +__version__ = "0.4.9" diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/__main__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9d7ccf4a174c7c41fab5148a8851353a7e7eb6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/__main__.py @@ -0,0 +1,527 @@ +import argparse +import json +import logging +import os +import sys +from functools import partial +from pathlib import Path +from typing import Union + +from dllm_eval import evaluator, utils +from dllm_eval.evaluator import request_caching_arg_to_dict +from dllm_eval.loggers import EvaluationTracker, WandbLogger +from dllm_eval.tasks import TaskManager +from dllm_eval.utils import ( + handle_non_serializable, + make_table, + simple_parse_args_string, +) + + +def try_parse_json(value: str) -> Union[str, dict, None]: + if value is None: + return None + try: + return json.loads(value) + except json.JSONDecodeError: + if "{" in value: + raise argparse.ArgumentTypeError( + f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings." + ) + return value + + +def _int_or_none_list_arg_type( + min_len: int, max_len: int, defaults: str, value: str, split_char: str = "," +): + def parse_value(item): + item = item.strip().lower() + if item == "none": + return None + try: + return int(item) + except ValueError: + raise argparse.ArgumentTypeError(f"{item} is not an integer or None") + + items = [parse_value(v) for v in value.split(split_char)] + num_items = len(items) + + if num_items == 1: + # Makes downstream handling the same for single and multiple values + items = items * max_len + elif num_items < min_len or num_items > max_len: + raise argparse.ArgumentTypeError( + f"Argument requires {max_len} integers or None, separated by '{split_char}'" + ) + elif num_items != max_len: + logging.warning( + f"Argument requires {max_len} integers or None, separated by '{split_char}'. " + "Missing values will be filled with defaults." + ) + default_items = [parse_value(v) for v in defaults.split(split_char)] + items.extend( + default_items[num_items:] + ) # extend items list with missing defaults + + return items + + +def check_argument_types(parser: argparse.ArgumentParser): + """ + Check to make sure all CLI args are typed, raises error if not + """ + for action in parser._actions: + if action.dest != "help" and not action.const: + if action.type is None: + raise ValueError( + f"Argument '{action.dest}' doesn't have a type specified." + ) + else: + continue + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`" + ) + parser.add_argument( + "--tasks", + "-t", + default=None, + type=str, + metavar="task1,task2", + help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above", + ) + parser.add_argument( + "--model_args", + "-a", + default="", + type=try_parse_json, + help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""", + ) + parser.add_argument( + "--num_fewshot", + "-f", + type=int, + default=None, + metavar="N", + help="Number of examples in few-shot context", + ) + parser.add_argument( + "--batch_size", + "-b", + type=str, + default=1, + metavar="auto|auto:N|N", + help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.", + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=None, + metavar="N", + help="Maximal batch size to try with --batch_size auto.", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to use (e.g. cuda, cuda:0, cpu).", + ) + parser.add_argument( + "--output_path", + "-o", + default=None, + type=str, + metavar="DIR|DIR/file.json", + help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.", + ) + parser.add_argument( + "--limit", + "-L", + type=float, + default=None, + metavar="N|0 argparse.Namespace: + check_argument_types(parser) + return parser.parse_args() + + +def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: + if not args: + # we allow for args to be passed externally, else we parse them ourselves + parser = setup_parser() + args = parse_eval_args(parser) + + if args.wandb_args: + wandb_args_dict = simple_parse_args_string(args.wandb_args) + wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args) + wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict) + + utils.setup_logging(args.verbosity) + eval_logger = logging.getLogger(__name__) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # update the evaluation tracker args with the output path and the HF token + if args.output_path: + args.hf_hub_log_args += f",output_path={args.output_path}" + if os.environ.get("HF_TOKEN", None): + args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}" + evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args) + evaluation_tracker = EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError( + "Specify --output_path if providing --log_samples or --predict_only" + ) + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." + ) + + if args.include_path is not None: + eval_logger.info(f"Including path: {args.include_path}") + metadata = ( + simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args + if isinstance(args.model_args, dict) + else {} + ) | ( + args.metadata + if isinstance(args.metadata, dict) + else simple_parse_args_string(args.metadata) + ) + + task_manager = TaskManager(include_path=args.include_path, metadata=metadata) + + if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples: + eval_logger.warning( + "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub." + ) + + if args.limit: + eval_logger.warning( + " --limit SHOULD ONLY BE USED FOR TESTING." + "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." + ) + if args.samples: + assert args.limit is None, ( + "If --samples is not None, then --limit must be None." + ) + if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + eval_logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" + f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues." + ) + + # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args + if args.trust_remote_code: + eval_logger.info( + "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`" + ) + # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, + # because it's already been determined based on the prior env var before launching our + # script--`datasets` gets imported by dllm_eval internally before these lines can update the env. + import datasets + + datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True + + args.model_args = args.model_args + ",trust_remote_code=True" + ( + eval_logger.info(f"Selected Tasks: {task_names}") + if eval_logger.getEffectiveLevel() >= logging.INFO + else print(f"Selected Tasks: {task_names}") + ) + + request_caching_args = request_caching_arg_to_dict( + cache_requests=args.cache_requests + ) + + results = evaluator.simple_evaluate( + model=args.model, + model_args=args.model_args, + tasks=task_names, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + max_batch_size=args.max_batch_size, + device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + if results is not None: + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps( + results, indent=2, default=handle_non_serializable, ensure_ascii=False + ) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging + if args.wandb_args: + try: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated( + results=results, samples=samples if args.log_samples else None + ) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples( + task_name=task_name, samples=samples[task_name] + ) + + if ( + evaluation_tracker.push_results_to_hub + or evaluation_tracker.push_samples_to_hub + ): + evaluation_tracker.recreate_metadata_card() + + print( + f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + + if args.wandb_args: + # Tear down wandb run once all the logging is done. + wandb_logger.run.finish() + + +if __name__ == "__main__": + cli_evaluate() diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/filter.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..bddbf3ab8d1bcbba804f9790ef0290d437bcde69 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/filter.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Iterable, List, Union + +from dllm_eval.api.instance import Instance + + +class Filter(ABC): + """ + Filter classes operate on a per-task level. + They take all model outputs (`instance.resps` for all `task.instances`) + across all instances of a task, and perform operations. + In a single run, one can configure any number of separate filters or lists of filters. + + """ + + def __init__(self, **kwargs) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + @abstractmethod + def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: + """ + Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. + Should return the list of (filtered) response lists *in the same order as they were input*, e.g. + if pass in [, ] should return + [, ] + """ + return resps + + +@dataclass +class FilterEnsemble: + """ + FilterEnsemble creates a pipeline applying multiple filters. + Its intended usage is to stack multiple post-processing steps in order. + `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each + pipeline separately. + """ + + name: str + filters: List[Callable[[], Filter]] + + def apply(self, instances: List[Instance]) -> None: + resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) + resps, docs = list(resps), list(docs) + + for f in self.filters: + # apply filters in sequence + resps = f().apply(resps, docs) + + # add the end results after filtering to filtered_requests of their respective source instances. + # has key `self.name`: each FilterEnsemble applied in a given run should use a different name. + for inst, resp in zip(instances, resps): + inst.filtered_resps[self.name] = resp diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/group.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/group.py new file mode 100644 index 0000000000000000000000000000000000000000..0c60739bbd26c79ecab91f54240798b2ae9e3313 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/group.py @@ -0,0 +1,115 @@ +import abc +from dataclasses import asdict, dataclass +from inspect import getsource +from typing import Any, Callable, List, Optional, Union + + +@dataclass +class AggMetricConfig(dict): + metric: Optional[str] = None + aggregation: Optional[str] = "mean" + weight_by_size: Optional[str] = False + # list of filter names which should be incorporated into the aggregated metric. + filter_list: Optional[Union[str, list]] = "none" + + def __post_init__(self): + if self.aggregation != "mean" and not callable(self.aggregation): + raise ValueError( + f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'." + ) + + if isinstance(self.filter_list, str): + self.filter_list = [self.filter_list] + + +@dataclass +class GroupConfig(dict): + group: Optional[str] = None + group_alias: Optional[str] = None + task: Optional[Union[str, list]] = None + aggregate_metric_list: Optional[ + Union[List[AggMetricConfig], AggMetricConfig, dict] + ] = None + metadata: Optional[dict] = ( + None # by default, not used in the code. allows for users to pass arbitrary info to tasks + ) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + return setattr(self, item, value) + + def __post_init__(self): + if self.aggregate_metric_list is not None: + if isinstance(self.aggregate_metric_list, dict): + self.aggregate_metric_list = [self.aggregate_metric_list] + + self.aggregate_metric_list = [ + AggMetricConfig(**item) if isinstance(item, dict) else item + for item in self.aggregate_metric_list + ] + + def to_dict(self, keep_callable: bool = False) -> dict: + """dumps the current config as a dictionary object, as a printable format. + null fields will not be printed. + Used for dumping results alongside full task configuration + + :return: dict + A printable dictionary version of the TaskConfig object. + + # TODO: should any default value in the TaskConfig not be printed? + """ + cfg_dict = asdict(self) + # remove values that are `None` + for k, v in list(cfg_dict.items()): + if callable(v): + cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) + return cfg_dict + + def serialize_function( + self, value: Union[Callable, str], keep_callable=False + ) -> Union[Callable, str]: + """Serializes a given function or string. + + If 'keep_callable' is True, the original callable is returned. + Otherwise, attempts to return the source code of the callable using 'getsource'. + """ + if keep_callable: + return value + else: + try: + return getsource(value) + except (TypeError, OSError): + return str(value) + + +class ConfigurableGroup(abc.ABC): + def __init__( + self, + config: Optional[dict] = None, + ) -> None: + self._config = GroupConfig(**config) + + @property + def group(self): + return self._config.group + + @property + def group_alias(self): + return self._config.group_alias + + @property + def version(self): + return self._config.version + + @property + def config(self): + return self._config.to_dict() + + @property + def group_name(self) -> Any: + return self._config.group + + def __repr__(self): + return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})" diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/instance.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/instance.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c6afa0644e729ba441728c72a2469fdad07b8f --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/instance.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass, field +from typing import Literal, Optional, Tuple + + +OutputType = Literal[ + "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice" +] + + +@dataclass +class Instance: + request_type: OutputType + doc: dict + arguments: tuple + idx: int + metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field( + default_factory=lambda: (None, None, None) + ) + resps: list = field(default_factory=list) + filtered_resps: dict = field(default_factory=dict) + + # initialized after init + task_name: Optional[str] = None + doc_id: Optional[int] = None + repeats: Optional[int] = None + + def __post_init__(self) -> None: + # unpack metadata field + self.task_name, self.doc_id, self.repeats = self.metadata + + @property + def args(self): + """ + Returns (string,) where `string` is the string to calculate loglikelihood over + """ + return ( + self.arguments if isinstance(self.arguments, tuple) else (self.arguments,) + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/metrics.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2aff6ce92a154a05df3d0bb7d28e09071cd12fbc --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/metrics.py @@ -0,0 +1,578 @@ +import logging +import math +import random +import re +import string +from collections.abc import Iterable +from typing import List + +import numpy as np +import sacrebleu + +from dllm_eval.api.registry import register_aggregation, register_metric + + +eval_logger = logging.getLogger(__name__) + + +# Register Aggregations First +@register_aggregation("bypass") +def bypass_agg(arr): + return 999 + + +@register_aggregation("nanmean") +def nanmean(arr): + if len(arr) == 0 or all(np.isnan(arr)): + return np.nan + return np.nanmean(arr) + + +@register_aggregation("mean") +def mean(arr): + return sum(arr) / len(arr) + + +@register_aggregation("median") +def median(arr): + return arr[len(arr) // 2] + + +# Certain metrics must be calculated across all documents in a benchmark. +# We use them as aggregation metrics, paired with no-op passthrough metric fns. +@register_aggregation("perplexity") +def perplexity(items): + return math.exp(-mean(items)) + + +@register_aggregation("weighted_perplexity") +def weighted_perplexity(items): + return math.exp(-weighted_mean(items)) + + +@register_aggregation("bits_per_byte") +def bits_per_byte(items): + return -weighted_mean(items) / math.log(2) + + +@register_aggregation("f1") +def f1_score(items): + from sklearn.metrics import f1_score + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + fscore = f1_score(golds, preds) + + return np.max(fscore) + + +@register_aggregation("matthews_corrcoef") +def matthews_corrcoef(items): + from sklearn.metrics import matthews_corrcoef + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + return matthews_corrcoef(golds, preds) + + +@register_aggregation("bleu") +def bleu(items): + """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric + for evaluating a generated sentence to a reference sentence. It counts matching + n-grams in the candidate translation to n-grams in the reference text, where + 1-gram or unigram would be each token and a bigram comparison would be each + word pair. The comparison is made regardless of word order + Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/ + Paper: https://www.aclweb.org/anthology/P02-1040/ + + Higher is better + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_bleu(preds, refs).score + + +@register_aggregation("chrf") +def chrf(items): + """chrF++ is a tool for automatic evaluation of machine translation output + based on character n-gram precision and recall enhanced with word n-grams. + Source: https://github.com/m-popovic/chrF + Paper: https://www.aclweb.org/anthology/W15-3049.pdf + + Higher is better # TODO I think + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_chrf(preds, refs).score + + +@register_aggregation("ter") +def ter(items): + """Translation Error Rate is an error metric for machine translation that + measures the number of edits required to change a system output into one + of the references + Source: http://www.cs.umd.edu/~snover/tercom/ + Paper: http://mt-archive.info/AMTA-2006-Snover.pdf + + Lower is better + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_ter(preds, refs).score + + +@register_aggregation("brier_score") +def brier_score(items): # This is a passthrough function + gold, predictions = list(zip(*items)) + bs, num_class = np.array(predictions).shape + + gold = list(gold) + gold_one_hot = np.eye(num_class)[gold] + return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1)) + + +@register_metric( + metric="brier_score", + higher_is_better=False, + output_type=["multiple_choice"], + aggregation="brier_score", +) +def brier_score_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice"], + aggregation="mean", +) +def acc_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_norm", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice"], + aggregation="mean", +) +def acc_norm_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_mutual_info", + higher_is_better=True, + output_type="multiple_choice", + aggregation="mean", +) +def acc_mutual_info_fn(items): # This is a passthrough function + return items + + +### the code used in the `exact_match_hf_evaluate` function is ported from +### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py +### which is under the apache license. + +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +def exact_match_hf_evaluate( + predictions, + references, + regexes_to_ignore=None, + ignore_case=False, + ignore_punctuation=False, + ignore_numbers=False, +): + if regexes_to_ignore is not None: + for s in regexes_to_ignore: + predictions = np.array([re.sub(s, "", x) for x in predictions]) + references = np.array([re.sub(s, "", x) for x in references]) + else: + predictions = np.asarray(predictions) + references = np.asarray(references) + + if ignore_case: + predictions = np.char.lower(predictions) + references = np.char.lower(references) + + if ignore_punctuation: + repl_table = string.punctuation.maketrans("", "", string.punctuation) + predictions = np.char.translate(predictions, table=repl_table) + references = np.char.translate(references, table=repl_table) + + if ignore_numbers: + repl_table = string.digits.maketrans("", "", string.digits) + predictions = np.char.translate(predictions, table=repl_table) + references = np.char.translate(references, table=repl_table) + + score_list = predictions == references + + return {"exact_match": np.mean(score_list)} + + +### + + +@register_metric( + metric="exact_match", + higher_is_better=True, + output_type="generate_until", + aggregation="mean", +) +def exact_match_fn(**kwargs): + return exact_match_hf_evaluate(**kwargs) + + +@register_metric( + metric="perplexity", + higher_is_better=False, + output_type="loglikelihood", + aggregation="perplexity", +) +def perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="word_perplexity", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="weighted_perplexity", +) +def word_perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="byte_perplexity", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="weighted_perplexity", +) +def byte_perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="bits_per_byte", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="bits_per_byte", +) +def bits_per_byte_fn(items): # This is a passthrough function + return items + + +def pop_stddev(arr): + mu = mean(arr) + return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr)) + + +def sample_stddev(arr): + mu = mean(arr) + return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1)) + + +def mean_stderr(arr): + return sample_stddev(arr) / math.sqrt(len(arr)) + + +@register_metric( + metric="bypass", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice", "generate_until"], + aggregation="bypass", +) +def bypass(items): + return None + + +@register_metric( + metric="mcc", + higher_is_better=True, + output_type="multiple_choice", + aggregation="matthews_corrcoef", +) +def mcc_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="f1", + higher_is_better=True, + output_type="multiple_choice", + aggregation="f1", +) +def f1_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="bleu", + higher_is_better=True, + output_type="generate_until", + aggregation="bleu", +) +def bleu_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="chrf", + higher_is_better=True, + output_type="generate_until", + aggregation="chrf", +) +def chrf_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="ter", + higher_is_better=True, + output_type="generate_until", + aggregation="ter", +) +def ter_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_all", + higher_is_better=True, + output_type="loglikelihood", + aggregation="mean", +) +def acc_all(items): + # Only count as correct if all answers are labeled correctly for each question + question_scoring_dict = {} + preds = list(zip(*items))[0] + docs = list(zip(*items))[1] + + for doc, pred in zip(docs, preds): + paragraph_id = doc["idx"]["paragraph"] + question_id = doc["idx"]["question"] + if (paragraph_id, question_id) not in question_scoring_dict: + question_scoring_dict[(paragraph_id, question_id)] = [] + + gold_label = doc["label"] == 1 + + question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred) + acc = np.mean([int(all(x)) for x in question_scoring_dict.values()]) + return acc + + +def acc_all_stderr(items): + # Only count as correct if all answers are labeled correctly for each question + question_scoring_dict = {} + preds = list(zip(*items))[0] + docs = list(zip(*items))[1] + + for doc, pred in zip(docs, preds): + question_id = doc["idx"]["question"] + if question_id not in question_scoring_dict: + question_scoring_dict[question_id] = [] + + gold_label = doc["label"] == 1 + question_scoring_dict[question_id].append(gold_label == pred) + + acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()]) + return acc + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + """Compute max metric between prediction and each ground truth.""" + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def weighted_mean(items): + a, b = zip(*items) + return sum(a) / sum(b) + + +def is_non_str_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + +def _sacreformat(refs, preds): + """Format refs and preds for sacrebleu corpus calculation. It is very particular""" + # Sacrebleu expects (List[str], List[List[str]) + # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...]) + + # Note [ref1_stream] is the first reference for each pred. + # So lists are size N and (M, N) for N preds and M possible refs for each pred + # This is a different order of dimensions that I would expect + + # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds + # Must become List[List[str]] with the inner list corresponding to preds + if not is_non_str_iterable(refs): + refs = list(refs) + if not is_non_str_iterable(refs[0]): + refs = [[ref] for ref in refs] + refs = list(zip(*refs)) + # Note the number of refs in each ref list much match the number of preds + + # We expect preds to be List[str] or List[List[str]]. Must become List[str] + if not is_non_str_iterable(preds): + preds = list(preds) + if is_non_str_iterable(preds[0]): + assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}" + preds = [pred[0] for pred in preds] + + return refs, preds + + +# stderr stuff + + +class _bootstrap_internal: + def __init__(self, f, n) -> None: + self.f = f + self.n = n + + def __call__(self, v): + i, xs = v + rnd = random.Random() + rnd.seed(i) + res = [] + for _ in range(self.n): + res.append(self.f(rnd.choices(xs, k=len(xs)))) + return res + + +def bootstrap_stderr(f, xs, iters): + import multiprocessing as mp + + pool = mp.Pool(mp.cpu_count()) + # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something + # equivalent to stderr calculated without Bessel's correction in the stddev. + # Unfortunately, I haven't been able to figure out what the right correction is + # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but + # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator) + # Thankfully, shouldn't matter because our samples are pretty big usually anyways + res = [] + chunk_size = min(1000, iters) + from tqdm import tqdm + + print("bootstrapping for stddev:", f.__name__) + for bootstrap in tqdm( + pool.imap( + _bootstrap_internal(f, chunk_size), + [(i, xs) for i in range(iters // chunk_size)], + ), + total=iters // chunk_size, + ): + # sample w replacement + res.extend(bootstrap) + + pool.close() + return sample_stddev(res) + + +def stderr_for_metric(metric, bootstrap_iters: int): + if bootstrap_iters <= 0: + # return no function (don't compute stderr) if bootstrap iters = 0 + return None + + bootstrappable = [ + median, + matthews_corrcoef, + f1_score, + perplexity, + bleu, + chrf, + ter, + nanmean, + ] + + if metric in bootstrappable: + return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters) + + stderr = {mean: mean_stderr, acc_all: acc_all_stderr} + + return stderr.get(metric, None) + + +def pooled_sample_stderr(stderrs: List[float], sizes: List[int]): + # Used to aggregate bootstrapped stderrs across subtasks in a group, + # when we are weighting by the size of each subtask. + # + + assert len(stderrs) == len(sizes) + + # formula source: https://en.wikipedia.org/wiki/Pooled_variance + # and: https://stats.stackexchange.com/a/4841331 + # this empirically seems to match running `stderr_for_metric` on all instances + # from the subtasks concatenated with each other. + pooled_sample_var = ( + sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)]) + ) / (sum(sizes) - len(sizes)) + + return np.sqrt(pooled_sample_var / sum(sizes)) + + +def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None): + assert metrics is not None, ( + "Need to pass a list of each subtask's metric for this stderr aggregation" + ) + assert len(stderrs) == len(sizes) and len(sizes) == len(metrics) + + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation. + # This formula depends on sample means. + # removed because it seems to give erroneously huge stderrs for groupings of tasks + # and does not seem to match up with bootstrap-calculated stderrs for groups. + + ### don't use this unless a statistician has told you it's the right thing to do ### + + # accumulators: we'll aggregate pairwise N - 1 times + variance = stderrs[0] ** 2 + curr_size = sizes[0] + curr_score = metrics[0] + + for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]): + curr_score = ((curr_score * curr_size) + (score * size)) / ( + curr_size + size + ) # NOTE: this assumes our aggregation fn is "mean" + + variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / ( + curr_size + size - 1 + ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * ( + curr_score - score + ) ** 2 + + return np.sqrt(variance) + + +def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True): + # A helper function that is used to aggregate + # subtask scores cross-task. + # TODO: does not hold for non-mean aggregations + if not weight_by_size: + sizes = [1] * len(sizes) + + assert len(metrics) == len(sizes) + + return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/model.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9364a9312d78c1029e5edf38d61f192afca91334 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/model.py @@ -0,0 +1,493 @@ +import abc +import hashlib +import json +import logging +import os +from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union + +import transformers +from sqlitedict import SqliteDict +from tqdm import tqdm + +from dllm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +T = TypeVar("T", bound="LM") + + +class LM(abc.ABC): + def __init__(self) -> None: + """Defines the interface that should be implemented by all LM subclasses. + LMs are assumed to take text (strings) as input and yield strings as output + (inputs/outputs should be tokenization-agnostic.) + + """ + # set rank and world size to a single process, by default. + self._rank = 0 + self._world_size = 1 + self.cache_hook = CacheHook(None) + + @abc.abstractmethod + def loglikelihood(self, requests) -> List[Tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. + + :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. + """ + pass + + @abc.abstractmethod + def loglikelihood_rolling(self, requests) -> List[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. + - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. + Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: BOS/EOS + Max context length: 4 + Resulting input/prediction pairs: + + INPUT: BOS 0 1 2 + PRED: 0 1 2 3 + + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + + INPUT: 5 6 7 8 + PRED: 8 9 + + Observe that: + 1. Each token is predicted exactly once + 2. For the last pair, we provide the full context, but only score the last two tokens + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the BOS/EOS token. + Can also be overridden for custom cases by `prefix_token_id`. + """ + pass + + # TODO: Add an optional max length + @abc.abstractmethod + def generate_until(self, requests) -> List[str]: + """Generate greedily until a stopping sequence + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs). + context: str + Context string + gen_kwargs: dict + A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc. + :return: list[str] + A list of model generated continuations. + continuation: str + The generated continuation. + """ + pass + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt=True + ) -> str: + """ + Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM. + + :param chat_history: list[dict[str, str]] + A list of dictionaries with keys 'role' and 'content'. + Values are strings representing the role name and the content of the message, respectively. + :param add_generation_prompt: bool + Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message. + :return: str + A string representing the chat history in a format that can be used as input to the LM. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type." + ) + + @classmethod + def create_from_arg_string( + cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given argument string and additional config. + + Parameters: + - arg_string: A string containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. + """ + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + @classmethod + def create_from_arg_obj( + cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given arg_obj + + Parameters: + - arg_obj: A dict containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. + """ + + additional_config = {} if additional_config is None else additional_config + additional_config = { + k: v for k, v in additional_config.items() if v is not None + } + + return cls(**arg_dict, **additional_config) + + @property + def rank(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._rank + + @property + def world_size(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._world_size + + @property + def tokenizer_name(self) -> str: + """Must be defined for LM subclasses which implement Chat Templating. + Should return the name of the tokenizer or chat template used. + Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'tokenizer_name' property." + ) + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """Returns the chat template structure for user/assistant messages if a template is provided. + This method is intended to be overridden in a subclass to define a specific chat template format. + For models that do not support chat templates, this method returns None by default. + """ + + return "" + + def set_cache_hook(self, cache_hook) -> None: + self.cache_hook = cache_hook + + +### SQLite-based caching of LM responses +def hash_args(attr, args): + dat = json.dumps([attr] + list(args)) + return hashlib.sha256(dat.encode("utf-8")).hexdigest() + + +class CacheHook: + def __init__(self, cachinglm) -> None: + if cachinglm is None: + self.dbdict = None + return + + self.dbdict = cachinglm.dbdict + + def add_partial(self, attr, req, res) -> None: + if self.dbdict is None: + return + hsh = hash_args(attr, req) + self.dbdict[hsh] = res + + +class CachingLM: + def __init__(self, lm, cache_db) -> None: + """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. + + :param lm: LM + Underlying LM + :param cache_db: str + Path to cache db + """ + self.lm = lm + self.cache_db = cache_db + if os.path.dirname(cache_db): + os.makedirs(os.path.dirname(cache_db), exist_ok=True) + self.dbdict = SqliteDict(cache_db, autocommit=True) + + # add hook to lm + lm.set_cache_hook(self.get_cache_hook()) + + def __getattr__(self, attr: str): + lm_attr = getattr(self.lm, attr) + if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]: + eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM") + return lm_attr + + def fn(requests): + res = [] + remaining_reqs = [] + warned = False + # figure out which ones are cached and which ones are new + eval_logger.info( + f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..." + ) + for req in tqdm(requests, desc="Checking cached requests"): + hsh = hash_args(attr, req.args) + if attr == "generate_until" and req.args[1].get("do_sample", False): + # when we are doing non-greedy generation, don't use the cache + # (else every "randomly sampled" generation would be identical for repeats > 1). + if not warned: + eval_logger.warning( + f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests." + ) + warned = True + res.append(None) + remaining_reqs.append(req) + elif hsh in self.dbdict: + ob = self.dbdict[hsh] + + assert ob is not None + + res.append(ob) + else: + res.append(None) + remaining_reqs.append(req) + eval_logger.info( + f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" + ) + if remaining_reqs: + # actually run the LM on the requests that do not have cached results + rem_res = getattr(self.lm, attr)(remaining_reqs) + else: + rem_res = [] + + # stick the new ones back into the list and also cache any of the new ones + resptr = 0 + for req, r in zip(remaining_reqs, rem_res): + while res[resptr] is not None: + resptr += 1 + + res[resptr] = r + + # caching + hsh = hash_args(attr, req.args) + self.dbdict[hsh] = r + self.dbdict.commit() + + return res + + return fn + + def get_cache_hook(self): + return CacheHook(self) + + +class TemplateLM(LM): + """ + A class acting as intermediary between the LM base class + and boilerplate often included in other LM subclasses. + """ + + tokenizer = None + + @property + @abc.abstractmethod + def eot_token_id(self): + pass + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + return self.eot_token_id + + @abc.abstractmethod + def tok_encode(self, string: str, **kwargs) -> List[int]: + """ + Tokenize a string using the model's tokenizer and return a list of token IDs. + """ + pass + + @abc.abstractmethod + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + pass + + def _encode_pair( + self, context: str, continuation: str + ) -> Tuple[List[int], List[int]]: + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + model_class = getattr(self, "AUTO_MODEL_CLASS", None) + + if model_class == transformers.AutoModelForSeq2SeqLM: + context_enc = self.tok_encode(context) + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + else: + whole_enc = self.tok_encode(context + continuation) + context_enc = self.tok_encode(context) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood( + self, requests, disable_tqdm: bool = False + ) -> List[Tuple[float, bool]]: + new_reqs = [] + for context, continuation in [req.args for req in requests]: + if context == "": + # BOS or EOS as context + context_enc, continuation_enc = ( + [self.prefix_token_id], + self.tok_encode(continuation), + ) + else: + context_enc, continuation_enc = self._encode_pair(context, continuation) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) + + @abc.abstractmethod + def loglikelihood_rolling( + self, requests, disable_tqdm: bool = False + ) -> List[float]: + pass + + @abc.abstractmethod + def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: + pass + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """ + Set and get the appropriate chat template for the model. + This method sets the tokenizer's chat_template and returns the template string for reproducibility. + + The template selection logic is adapted from the Transformers library's `apply_chat_template` + method in the Tokenizer class. The original implementation can be found at: + https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687 + + This method ensures that the right template is chosen based on the following: + 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string. + 1. If the model's tokenizer has multiple templates: + a. Use the specified template if it exists in the dictionary. + b. Use the default template from the list if no specific template is provided. + c. Raise an error if no default template exists and no specific template is provided. + 2. If the model's tokenizer has a single template or no template: + a. Use the tokenizer's chat template if available. + b. Fall back to the default chat template if no tokenizer chat template exists. + + Args: + chat_template (Union[bool, str]): Specifies the chat template to use. + - If False or None, no template is applied. + - If True, the default or only available template is used. + - If a string, the template with the matching name is used. + + Returns: + Optional[str]: The selected chat template, or None if no template is applied. + """ + if self.tokenizer is None: + return "" + + if chat_template is False or chat_template is None: + eval_logger.warning( + "model.chat_template was called with the chat_template set to False or None. " + "Therefore no chat template will be applied. Make sure this is an intended behavior." + ) + return None + + # Convert boolean chat_template to None to ensure compatibility with the adapted logic + if isinstance(chat_template, bool): + chat_template = None + using_default_template = False + + # First, handle the cases when the model has a dict of multiple templates + try: + template = ( + self.tokenizer.chat_template or self.tokenizer.default_chat_template + ) + except AttributeError: + return None + + if isinstance(template, dict): + using_default_dict = self.tokenizer.chat_template is None + + if chat_template is not None: + if chat_template in template: + selected_template = template[chat_template] + if using_default_dict: + using_default_template = True + else: + raise ValueError( + f"The specified chat template '{chat_template}' is not available. " + f"Available template names are {sorted(template.keys())}." + ) + else: + # If user didn't pass a chat template, use the default template from the dict + if "default" in template: + selected_template = template["default"] + using_default_template = True + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template.keys())}." + ) + + # Cases when the model has a single template or no template + else: + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template + if isinstance(chat_template, str): + eval_logger.warning( + "Chat template name provided, but the tokenizer's chat template is not a dictionary. " + "Using the tokenizer's chat template or the default template instead." + ) + if self.tokenizer.chat_template is not None: + selected_template = self.tokenizer.chat_template + else: + selected_template = self.tokenizer.default_chat_template + using_default_template = True + + if using_default_template: + eval_logger.warning( + "No chat template is set for this tokenizer, falling back to a default class-level template. This is " + "very error-prone, because models are often trained with templates different from the class default! " + "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " + "point any code depending on them will stop working. We recommend setting a valid chat template before " + "then to ensure that this model continues working without issues." + ) + + return selected_template diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/registry.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2b2e415a0a19862a41bde307bbad2e6ba326f5 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/registry.py @@ -0,0 +1,196 @@ +import logging +from typing import Callable, Dict, Union + +import evaluate as hf_evaluate + +from dllm_eval.api.model import LM + + +eval_logger = logging.getLogger(__name__) + +MODEL_REGISTRY = {} + + +def register_model(*names): + # either pass a list or a single alias. + # function receives them as a tuple of strings + + def decorate(cls): + for name in names: + assert issubclass(cls, LM), ( + f"Model '{name}' ({cls.__name__}) must extend LM class" + ) + + assert name not in MODEL_REGISTRY, ( + f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead." + ) + + MODEL_REGISTRY[name] = cls + return cls + + return decorate + + +def get_model(model_name): + try: + return MODEL_REGISTRY[model_name] + except KeyError: + raise ValueError( + f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" + ) + + +TASK_REGISTRY = {} +GROUP_REGISTRY = {} +ALL_TASKS = set() +func2task_index = {} + + +def register_task(name): + def decorate(fn): + assert name not in TASK_REGISTRY, ( + f"task named '{name}' conflicts with existing registered task!" + ) + + TASK_REGISTRY[name] = fn + ALL_TASKS.add(name) + func2task_index[fn.__name__] = name + return fn + + return decorate + + +def register_group(name): + def decorate(fn): + func_name = func2task_index[fn.__name__] + if name in GROUP_REGISTRY: + GROUP_REGISTRY[name].append(func_name) + else: + GROUP_REGISTRY[name] = [func_name] + ALL_TASKS.add(name) + return fn + + return decorate + + +OUTPUT_TYPE_REGISTRY = {} +METRIC_REGISTRY = {} +METRIC_AGGREGATION_REGISTRY = {} +AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} +HIGHER_IS_BETTER_REGISTRY = {} +FILTER_REGISTRY = {} + +DEFAULT_METRIC_REGISTRY = { + "loglikelihood": [ + "perplexity", + "acc", + ], + "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"], + "multiple_choice": ["acc", "acc_norm"], + "generate_until": ["exact_match"], +} + + +def register_metric(**args): + # TODO: do we want to enforce a certain interface to registered metrics? + def decorate(fn): + assert "metric" in args + name = args["metric"] + + for key, registry in [ + ("metric", METRIC_REGISTRY), + ("higher_is_better", HIGHER_IS_BETTER_REGISTRY), + ("aggregation", METRIC_AGGREGATION_REGISTRY), + ]: + if key in args: + value = args[key] + assert value not in registry, ( + f"{key} named '{value}' conflicts with existing registered {key}!" + ) + + if key == "metric": + registry[name] = fn + elif key == "aggregation": + registry[name] = AGGREGATION_REGISTRY[value] + else: + registry[name] = value + + return fn + + return decorate + + +def get_metric(name: str, hf_evaluate_metric=False) -> Callable: + if not hf_evaluate_metric: + if name in METRIC_REGISTRY: + return METRIC_REGISTRY[name] + else: + eval_logger.warning( + f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..." + ) + + try: + metric_object = hf_evaluate.load(name) + return metric_object.compute + except Exception: + eval_logger.error( + f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric", + ) + + +def register_aggregation(name: str): + def decorate(fn): + assert name not in AGGREGATION_REGISTRY, ( + f"aggregation named '{name}' conflicts with existing registered aggregation!" + ) + + AGGREGATION_REGISTRY[name] = fn + return fn + + return decorate + + +def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: + try: + return AGGREGATION_REGISTRY[name] + except KeyError: + eval_logger.warning(f"{name} not a registered aggregation metric!") + + +def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: + try: + return METRIC_AGGREGATION_REGISTRY[name] + except KeyError: + eval_logger.warning(f"{name} metric is not assigned a default aggregation!") + + +def is_higher_better(metric_name) -> bool: + try: + return HIGHER_IS_BETTER_REGISTRY[metric_name] + except KeyError: + eval_logger.warning( + f"higher_is_better not specified for metric '{metric_name}'!" + ) + + +def register_filter(name): + def decorate(cls): + if name in FILTER_REGISTRY: + eval_logger.info( + f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}" + ) + FILTER_REGISTRY[name] = cls + return cls + + return decorate + + +def get_filter(filter_name: Union[str, Callable]) -> Callable: + try: + return FILTER_REGISTRY[filter_name] + except KeyError as e: + if callable(filter_name): + return filter_name + else: + eval_logger.warning(f"filter `{filter_name}` is not registered!") + raise e diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/samplers.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..969789ef2111dcb8ee3b7eed4c69d54572d6c302 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/samplers.py @@ -0,0 +1,232 @@ +import logging +import warnings +from functools import partial +from typing import TYPE_CHECKING, Iterable, Optional, Union + +import datasets + + +if TYPE_CHECKING: + from random import Random + + from dllm_eval.api.task import ConfigurableTask, Task + +eval_logger = logging.getLogger("lm-eval") + + +class ContextSampler: + def __init__( + self, + docs: list[dict], + task: Union["Task", "ConfigurableTask"], + fewshot_indices: Optional[Iterable] = None, + rnd: Optional["Random"] = None, + ) -> None: + self.rnd = rnd + if not self.rnd: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!" + ) + + self.task = task + self.config = task._config + + self.target_delimiter = self.config.target_delimiter + self.fewshot_delimiter = self.config.fewshot_delimiter + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_text", None) is not None + ): + self.doc_to_text = partial( + self.task.doc_to_text, + doc_to_text=self.config.fewshot_config.get("doc_to_text", None), + ) + else: + self.doc_to_text = self.task.doc_to_text + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_target", None) is not None + ): + self.doc_to_target = partial( + self.task.doc_to_target, + doc_to_target=self.config.fewshot_config.get("doc_to_target", None), + ) + else: + self.doc_to_target = self.task.doc_to_target + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_choice", None) is not None + ): + self.doc_to_choice = partial( + self.task.doc_to_choice, + doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None), + ) + else: + self.doc_to_choice = self.task.doc_to_choice + + self.docs = docs # HF dataset split, provided by task._fewshot_docs() + if fewshot_indices: # subset few-shot docs from + if not isinstance(self.docs, datasets.Dataset): + raise ValueError( + "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously" + ) + self.docs = self.docs.select(fewshot_indices) + + def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None): + # draw an extra fewshot sample if using same split as evaluating on + prefix = gen_prefix + " " if gen_prefix else "" + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = "" + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + if self.config.doc_to_choice is None or isinstance(doc_content, str): + labeled_examples += doc_content + else: + labeled_examples += self.doc_to_choice(doc)[doc_content] + + if doc_target != "": + if self.target_delimiter.isspace() and str(doc_target)[0].isspace(): + # TODO: add logger warn once here. + warnings.warn( + "Both target_delimiter and target start with a space. This may cause issues.", + Warning, + stacklevel=2, + ) + labeled_examples += self.target_delimiter + labeled_examples += prefix + labeled_examples += ( + str(doc_target[0]) + if isinstance(doc_target, list) + else doc_target + if self.config.doc_to_choice is None or isinstance(doc_target, str) + else str(self.doc_to_choice(doc)[doc_target]) + ) + labeled_examples += self.fewshot_delimiter + + return labeled_examples + + def get_chat_context( + self, + doc: dict, + num_fewshot: int, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ): + # TODO: Do we need any other delimiter + prefix = gen_prefix + " " if gen_prefix else "" + chat_history = [] + # draw an extra fewshot sample if using same split as evaluating on + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + if fewshot_as_multiturn: + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + chat_history.append( + { + "role": "user", + "content": doc_content + if self.config.doc_to_choice is None + or isinstance(doc_content, str) + else self.doc_to_choice(doc)[doc_content], + } + ) + chat_history.append( + { + "role": "assistant", + "content": prefix + str(doc_target[0]) + if isinstance(doc_target, list) + else prefix + doc_target + if self.config.doc_to_choice is None + or isinstance(doc_target, str) + else prefix + str(self.doc_to_choice(doc)[doc_target]), + } + ) + else: + # get fewshot context as one user turn + chat_history.append( + { + "role": "user", + "content": self.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ), + } + ) + + return chat_history + + def sample(self, n: int): + """ + Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. + """ + + return self.rnd.sample(self.docs, n) + + +class FirstNSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + Draw the first `n` samples in order from the specified split. + Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. + """ + assert n <= len(self.docs), ( + f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available." + ) + return self.docs[:n] + + +class BalancedSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + TODO: this should return approximately class-balanced samples from our fewshot examples. + TODO: what order should they be in? maybe random? + """ + + pass + + +class ManualSampler(ContextSampler): + def sample(self, n: int) -> None: + """ """ + pass + + +SAMPLER_REGISTRY = { + "default": ContextSampler, + "first_n": FirstNSampler, +} + + +def get_sampler(name: str): + try: + return SAMPLER_REGISTRY[name] + except KeyError: + raise ValueError( + f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}" + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/task.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/task.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6321af0b2b8777e0322745a9875656ec194190 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/api/task.py @@ -0,0 +1,1881 @@ +import abc +import ast +import logging +import random +import re +from collections.abc import Callable +from copy import deepcopy +from dataclasses import asdict, dataclass +from inspect import getsource +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + Union, +) + +import datasets +import numpy as np +from tqdm import tqdm + +from dllm_eval import utils +from dllm_eval.api import samplers +from dllm_eval.api.instance import Instance, OutputType +from dllm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity +from dllm_eval.api.registry import ( + AGGREGATION_REGISTRY, + DEFAULT_METRIC_REGISTRY, + get_aggregation, + get_metric, + get_metric_aggregation, + is_higher_better, +) +from dllm_eval.caching.cache import load_from_cache, save_to_cache +from dllm_eval.filters import build_filter_ensemble +from dllm_eval.prompts import get_prompt + + +ALL_OUTPUT_TYPES = [ + "loglikelihood", + "multiple_choice", + "loglikelihood_rolling", + "generate_until", +] + +eval_logger = logging.getLogger(__name__) + + +@dataclass +class TaskConfig(dict): + # task naming/registry + task: Optional[str] = None + task_alias: Optional[str] = None + tag: Optional[Union[str, list]] = None + # HF dataset options. + # which dataset to use, + # and what splits for what purpose + custom_dataset: Optional[Callable] = None + dataset_path: Optional[str] = None + dataset_name: Optional[str] = None + dataset_kwargs: Optional[dict] = None + training_split: Optional[str] = None + validation_split: Optional[str] = None + test_split: Optional[str] = None + fewshot_split: Optional[str] = ( + None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) + ) + # formatting / prompting options. + # see docs/advanced_task_guide.md for more info + process_docs: Optional[Callable] = None + doc_to_text: Optional[Union[Callable, str]] = None + doc_to_target: Optional[Union[Callable, str]] = None + doc_to_image: Union[Callable, str] = None + doc_to_audio: Union[Callable, str] = None + unsafe_code: bool = False + doc_to_choice: Optional[Union[Callable, str, dict, list]] = None + process_results: Optional[Union[Callable, str]] = None + use_prompt: Optional[str] = None + description: str = "" + target_delimiter: str = " " + fewshot_delimiter: str = "\n\n" + fewshot_config: Optional[dict] = None + # runtime configuration options + num_fewshot: Optional[int] = None + # scoring options + metric_list: Optional[list] = None + output_type: OutputType = "generate_until" + generation_kwargs: Optional[dict] = None + repeats: int = 1 + filter_list: Optional[Union[str, list]] = None + should_decontaminate: bool = False + doc_to_decontamination_query: Optional[str] = None + gen_prefix: Optional[str] = None + metadata: Optional[dict] = ( + None # by default, not used in the code. allows for users to pass arbitrary info to tasks + ) + + def __post_init__(self) -> None: + if self.generation_kwargs is not None: + if self.output_type != "generate_until": + eval_logger.warning( + f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!" + ) + + if "temperature" in self.generation_kwargs: + self.generation_kwargs["temperature"] = float( + self.generation_kwargs["temperature"] + ) + + if "until" not in self.generation_kwargs: + eval_logger.warning( + f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}" + ) + self.generation_kwargs["until"] = [self.fewshot_delimiter] + else: + if self.output_type == "generate_until": + # ensure that we greedily generate in absence of explicit arguments otherwise + self.generation_kwargs = { + "until": ( + None + if self.fewshot_delimiter is None + else [self.fewshot_delimiter] + ), + "do_sample": False, + "temperature": 0, + } + eval_logger.warning( + f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}" + ) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + return setattr(self, item, value) + + def to_dict(self, keep_callable: bool = False) -> dict: + """dumps the current config as a dictionary object, as a printable format. + null fields will not be printed. + Used for dumping results alongside full task configuration + + :return: dict + A printable dictionary version of the TaskConfig object. + + # TODO: should any default value in the TaskConfig not be printed? + """ + cfg_dict = asdict(self) + # remove values that are `None` + for k, v in list(cfg_dict.items()): + if v is None: + cfg_dict.pop(k) + elif k == "metric_list": + for metric_dict in v: + for metric_key, metric_value in metric_dict.items(): + if callable(metric_value): + metric_dict[metric_key] = self.serialize_function( + metric_value, keep_callable=keep_callable + ) + cfg_dict[k] = v + elif callable(v): + cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) + return cfg_dict + + def serialize_function( + self, value: Union[Callable, str], keep_callable=False + ) -> Union[Callable, str]: + """Serializes a given function or string. + + If 'keep_callable' is True, the original callable is returned. + Otherwise, attempts to return the source code of the callable using 'getsource'. + """ + if keep_callable: + return value + else: + try: + return getsource(value) + except (TypeError, OSError): + return str(value) + + +class Task(abc.ABC): + """A task represents an entire benchmark including its dataset, problems, + answers, and evaluation methods. See BoolQ for a simple example implementation + + A `doc` can be any python object which represents one instance of evaluation. + This is usually a dictionary e.g. + {"question": ..., "answer": ...} or + {"question": ..., question, answer) + """ + + VERSION: Optional[Union[int, str]] = None + + # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub + # or a path to a custom `datasets` loading script. + DATASET_PATH: Optional[str] = None + + # The name of a subset within `DATASET_PATH`. + DATASET_NAME: Optional[str] = None + + OUTPUT_TYPE: Optional[OutputType] = None + + def __init__( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode: Optional[datasets.DownloadMode] = None, + config: Optional[Mapping] = None, # Union[dict, TaskConfig] + ) -> None: + """ + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.download(data_dir, cache_dir, download_mode) + self._training_docs: Optional[list] = None + self._fewshot_docs: Optional[list] = None + self._instances: Optional[List[Instance]] = None + + self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig() + + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + self.fewshot_rnd: Optional[random.Random] = ( + None # purposely induce errors in case of improper usage + ) + + def download( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode=None, + ) -> None: + """Downloads and returns the task dataset. + Override this method to download the dataset from a custom API. + + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + by setting the shell environment variable, `HF_DATASETS_CACHE`, + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + data_dir=data_dir, + cache_dir=cache_dir, + download_mode=download_mode, + ) + + @property + def config(self) -> TaskConfig: + """Returns the TaskConfig associated with this class.""" + return self._config + + @abc.abstractmethod + def has_training_docs(self): + """Whether the task has a training set""" + pass + + @abc.abstractmethod + def has_validation_docs(self): + """Whether the task has a validation set""" + pass + + @abc.abstractmethod + def has_test_docs(self): + """Whether the task has a test set""" + pass + + def training_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def validation_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def test_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def fewshot_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + if self.has_training_docs(): + return self.training_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + if self.config.get("num_fewshot", 0) > 0: + eval_logger.warning( + f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False" + ", using test_docs as fewshot_docs but this is not recommended." + ) + return self.test_docs() + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. + """ + return doc + + @property + def instances(self) -> List[Instance]: + """After calling `task.build_all_requests()`, tasks + maintain a list of the dataset instances which will be evaluated. + """ + return self._instances + + def fewshot_examples(self, k, rnd): + if self._training_docs is None: + self._training_docs = list(self.training_docs()) + + return rnd.sample(self._training_docs, k) + + def doc_to_decontamination_query(self, doc): + raise NotImplementedError( + "Override doc_to_decontamination_query with document specific decontamination query." + ) + + @abc.abstractmethod + def doc_to_text(self, doc): + pass + + @abc.abstractmethod + def doc_to_target(self, doc): + pass + + # not an abstractmethod because not every language-only task has to implement this + def doc_to_image(self, doc): + raise NotImplementedError + + def doc_to_audio(self, doc): + raise NotImplementedError + + def doc_to_prefix(self, doc): + return "" + + def build_all_requests( + self, + *, + limit: Union[int, None] = None, + samples: Optional[List[int]] = None, + rank: int = 0, + world_size: int = 1, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + tokenizer_name: str = "", + ) -> None: + """Build a set of Instances for a task, and store them in task.instances""" + + # used with caching + og_limit = limit + + cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" + cache_key += "-chat_template" if apply_chat_template else "" + cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" + cache_key += ( + f"-system_prompt_hash{utils.hash_string(system_instruction)}" + if system_instruction is not None + else "" + ) + cache_key += f"-tokenizer{tokenizer_name}" + + cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests) + + if cache_requests and cached_instances and not rewrite_requests_cache: + cached_instances = cached_instances[:limit] + + flattened_instances = [ + instance + for instance_group in cached_instances + for instance in instance_group + ] + + self._instances = flattened_instances + return + + eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...") + + instances = [] + + # process all documents when caching is specified for simplicity + if ( + cache_requests + and (not cached_instances or rewrite_requests_cache) + and limit is not None + ): + limit = None + + doc_id_docs = list( + self.doc_iterator( + rank=rank, limit=limit, samples=samples, world_size=world_size + ) + ) + + num_docs = len(doc_id_docs) + + for doc_id, doc in tqdm( + doc_id_docs, + total=num_docs, + ): + # sample fewshot context #TODO: need to offset doc_id by rank now! + fewshot_ctx = self.fewshot_context( + doc, + num_fewshot=0 + if self.config.num_fewshot is None + else self.config.num_fewshot, + system_instruction=system_instruction, + apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, + chat_template=chat_template, + gen_prefix=self.doc_to_prefix(doc), + ) + + # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute + inst = self.construct_requests( + doc=doc, + ctx=fewshot_ctx, + metadata=(self.config["task"], doc_id, self.config.repeats), + apply_chat_template=apply_chat_template, + chat_template=chat_template, + ) + + if not isinstance(inst, list): + inst = [inst] + + instances.append(inst) + + # now flatten, this is to allow slicing to work with pickles + + sliced_instances = instances[:og_limit] + + flattened_instances = [ + instance + for instance_group in sliced_instances + for instance in instance_group + ] + + self._instances = flattened_instances + + if len(self._instances) == 0: + raise ValueError("task.build_requests() did not find any docs!") + + if cache_requests and (not cached_instances or rewrite_requests_cache): + save_to_cache(file_name=cache_key, obj=instances) + + @abc.abstractmethod + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + :param doc_idx: int + The index of a document within `self.test_docs()` or `self.validation_docs()`, + whichever is the main split used. + :param repeats: int + TODO: update this docstring + The number of times each instance in a dataset is inferred on. Defaults to 1, + can be increased for techniques like majority voting. + """ + pass + + @abc.abstractmethod + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. + """ + pass + + @abc.abstractmethod + def aggregation(self): + """ + :returns: {str: [metric_score] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metric scores + """ + pass + + @abc.abstractmethod + def higher_is_better(self): + """ + :returns: {str: bool} + A dictionary where keys are the names of submetrics and values are + whether a higher value of the submetric is better + """ + pass + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @classmethod + def count_bytes(cls, doc): + """Used for byte-level perplexity metrics in rolling loglikelihood""" + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc): + """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) + + @utils.positional_deprecated + def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs): + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param rnd: random.Random + The pseudo-random number generator used to randomly sample examples. + WARNING: This is currently a required arg although it's optionalized with a default `None`. + :param description: str + The task's description that will be prepended to the fewshot examples. + :returns: str + The fewshot context. + """ + if rnd is None: + if self.fewshot_rnd is not None: + rnd = self.fewshot_rnd + else: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd`" + ) + + description = description if description else "" + + if num_fewshot == 0: + labeled_examples = "" + else: + # for sets with no training docs, draw from other set *but ensure no overlap with current doc* + if self.has_training_docs(): + fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) + else: + if self._fewshot_docs is None: + self._fewshot_docs = list( + self.validation_docs() + if self.has_validation_docs() + else self.test_docs() + ) + + fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = ( + "\n\n".join( + [ + self.doc_to_text(doc) + self.doc_to_target(doc) + for doc in fewshotex + ] + ) + + "\n\n" + ) + + example = self.doc_to_text(doc) + return description + labeled_examples + example + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def dump_config(self) -> dict: + """Returns the config as a dictionary.""" + # TODO: this should only return the overrides applied to a non-YAML task's configuration. + # (num_fewshot) + return self.config.to_dict() + + def set_config(self, key: str, value: Any, update: bool = False) -> None: + """Set or update the configuration for a given key.""" + if key is None: + raise ValueError("Key must be provided.") + + if update: + current_value = getattr(self._config, key, {}) + if not isinstance(current_value, dict): + raise TypeError( + f"Expected a dict for key '{key}', got {type(current_value).__name__} instead." + ) + current_value.update(value) + else: + setattr(self._config, key, value) + + def override_metric(self, metric_name: str) -> None: + """ + Override the default metrics used for evaluation with custom metrics. + + Parameters: + - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics. + """ + ( + self._metric_fn_list, + self._aggregation_list, + self._metric_fn_kwargs, + self._higher_is_better, + ) = ({}, {}, {}, {}) + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._aggregation_list[metric_name] = get_metric_aggregation(metric_name) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + self._metric_fn_kwargs[metric_name] = {} + if not isinstance(self, ConfigurableTask): + self.process_results = lambda x, y: {metric_name: get_metric(metric_name)} + self.aggregation = lambda: { + metric_name: get_metric_aggregation(metric_name) + } + setattr(self._config, "metric_list", [{"metric": metric_name}]) + setattr(self._config, "process_results", None) + + def set_fewshot_seed(self, seed: Optional[int] = None) -> None: + self.fewshot_rnd = random.Random(seed) + if hasattr(self, "sampler"): + self.sampler.rnd = self.fewshot_rnd + + @property + def eval_docs(self) -> Union[datasets.Dataset, List[dict]]: + if self.has_test_docs(): + return self.test_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + raise ValueError( + f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" + ) + + def doc_iterator( + self, + *, + rank: int = 0, + limit: Union[int, None] = None, + world_size: int = 1, + samples: Optional[List[int]] = None, + ) -> Iterator[Tuple[int, Any]]: + if samples: + n = len(self.eval_docs) + assert all([e < n for e in samples]), ( + f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}." + ) + eval_logger.info( + f"{self.config.task}: Evaluating on {len(samples)} examples" + ) + doc_iterator = utils.create_iterator( + enumerate(x for i, x in enumerate(self.eval_docs) if i in samples), + rank=int(rank), + limit=None, # limit does not matter here since we are selecting samples directly + world_size=int(world_size), + ) + else: + limit = int(limit) if limit else None + doc_iterator = utils.create_iterator( + enumerate(self.eval_docs), + rank=int(rank), + limit=limit, + world_size=int(world_size), + ) + return doc_iterator + + +class ConfigurableTask(Task): + VERSION = "Yaml" + OUTPUT_TYPE = None + CONFIG = None + + def __init__( + self, + data_dir=None, + cache_dir=None, + download_mode=None, + config: Optional[dict] = None, + ) -> None: # TODO no super() call here + # Get pre-configured attributes + self._config = self.CONFIG + + # Use new configurations if there was no preconfiguration + if self.config is None: + self._config = TaskConfig(**config) + # Overwrite configs + else: + if config is not None: + self._config.__dict__.update(config) + + if self.config is None: + raise ValueError( + "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" + ) + + if isinstance(self.config.metadata, dict): + if "version" in self.config.metadata: + self.VERSION = self.config.metadata["version"] + + if self.config.output_type is not None: + if self.config.output_type not in ALL_OUTPUT_TYPES: + raise ValueError( + f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'" + ) + self.OUTPUT_TYPE = self.config.output_type + + if self.config.doc_to_image is not None: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.doc_to_audio: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.unsafe_code is not False: + self.UNSAFE_CODE = True + + if self.config.dataset_path is not None: + self.DATASET_PATH = self.config.dataset_path + + if self.config.dataset_name is not None: + self.DATASET_NAME = self.config.dataset_name + + self._metric_fn_list = {} + self._metric_fn_kwargs = {} + self._aggregation_list = {} + self._higher_is_better = {} + + if self.config.metric_list is None: + # TODO: handle this in TaskConfig.__post_init__ ? + _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type] + + for metric_name in _metric_list: + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._metric_fn_kwargs[metric_name] = {} + self._aggregation_list[metric_name] = get_metric_aggregation( + metric_name + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + else: + for metric_config in self.config.metric_list: + if "metric" not in metric_config: + raise ValueError( + "'metric' key not provided for an entry in 'metric_list', must be specified!" + ) + metric_name = metric_config["metric"] + kwargs = { + key: metric_config[key] + for key in metric_config + if key + not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"] + } + hf_evaluate_metric = ( + "hf_evaluate" in metric_config + and metric_config["hf_evaluate"] is True + ) + + if self.config.process_results is not None: + self._metric_fn_list[metric_name] = None + self._metric_fn_kwargs[metric_name] = {} + elif callable(metric_name): + metric_fn = metric_name.__call__ + metric_name = metric_name.__name__ + self._metric_fn_list[metric_name] = metric_fn + self._metric_fn_kwargs[metric_name] = kwargs + else: + self._metric_fn_list[metric_name] = get_metric( + metric_name, hf_evaluate_metric + ) + self._metric_fn_kwargs[metric_name] = kwargs + + if "aggregation" in metric_config: + agg_name = metric_config["aggregation"] + if isinstance(agg_name, str): + self._aggregation_list[metric_name] = get_aggregation(agg_name) + elif callable(agg_name): # noqa: E721 + self._aggregation_list[metric_name] = metric_config[ + "aggregation" + ] + else: + INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} + metric_agg = get_metric_aggregation(metric_name) + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. " + f"using default " + f"aggregation={INV_AGG_REGISTRY[metric_agg]}" + ) + self._aggregation_list[metric_name] = metric_agg + + if "higher_is_better" in metric_config: + self._higher_is_better[metric_name] = metric_config[ + "higher_is_better" + ] + else: + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. " + f"using default " + f"higher_is_better={is_higher_better(metric_name)}" + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + + self.download(self.config.dataset_kwargs) + self._training_docs = None + self._fewshot_docs = None + + if self.config.filter_list is not None: + self._filters = [] + for filter_config in self.config.filter_list: + filter_name = filter_config["name"] + filter_functions = filter_config["filter"] + components = [] + for function in filter_functions: + kwargs = { + key: function[key] for key in function if key != "function" + } + components.append([function["function"], kwargs]) + filter_pipeline = build_filter_ensemble(filter_name, components) + self._filters.append(filter_pipeline) + else: + # TODO: handle repeats in a more general way rather than just discarding + eval_logger.debug( + "No custom filters defined. Using default 'take_first' filter for handling repeats." + ) + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + + if self.config.use_prompt is not None: + eval_logger.info(f"loading prompt {self.config.use_prompt}") + self.prompt = get_prompt( + self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME + ) + else: + self.prompt = None + + if self.fewshot_docs() is not None: + self.fewshot_rnd = ( + random.Random() + ) # setting with no seed, to be overridden at a later time + config_sampler: Union[str, Callable] = ( + self.config.fewshot_config.get("sampler", "default") + if self.config.fewshot_config + else "default" + ) + if isinstance(config_sampler, str): + self.sampler = samplers.get_sampler(config_sampler)( + list(self.fewshot_docs()), self, rnd=self.fewshot_rnd + ) + elif callable(config_sampler) and issubclass( + config_sampler, samplers.ContextSampler + ): + self.sampler = config_sampler( + docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd + ) + else: + raise TypeError( + f"fewshot_config.sampler should be a string or callable of ContextSampler type, " + f"not {type(config_sampler)}" + ) + + self.task_docs = self.eval_docs + + # Test One Doc + self.features = list(self.task_docs.features.keys()) + self.multiple_input = 0 + self.multiple_target = 0 + test_doc = self.task_docs[0] + test_text = self.doc_to_text(test_doc) + test_target = self.doc_to_target(test_doc) + + if self.config.doc_to_choice is not None: + test_choice = self.doc_to_choice(test_doc) + if not isinstance(test_choice, list): + eval_logger.error("doc_to_choice must return list") + else: + num_choice = len(test_choice) + + if isinstance(test_text, int): + eval_logger.debug( + "doc_to_text returned an int. Assuming multiple inputs." + ) + self.multiple_input = num_choice + else: + test_choice = None + + if isinstance(test_target, list): + eval_logger.debug( + "doc_to_target returned a list. Assuming multiple targets." + ) + self.multiple_target = len(test_target) + else: + if (isinstance(test_target, int)) and (test_choice is not None): + test_target = test_choice[test_target] + else: + test_target = str(test_target) + + if test_choice is not None: + check_choices = test_choice + else: + check_choices = [test_target] + if self.config.doc_to_choice is not None: + for choice in check_choices: + choice_has_whitespace = True if choice[0].isspace() else False + delimiter_has_whitespace = ( + True + if self.config.target_delimiter.rstrip() + != self.config.target_delimiter + else False + ) + + if delimiter_has_whitespace and choice_has_whitespace: + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace' + ) + elif (not delimiter_has_whitespace) and (not choice_has_whitespace): + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' + ) + + def download( + self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs + ) -> None: + if isinstance(self.config.custom_dataset, Callable): + eval_logger.warning( + f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager." + + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme." + ) + self.dataset = self.config.custom_dataset( + **(self.config.metadata or {}), **(self.config.dataset_kwargs or {}) + ) + else: + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + **dataset_kwargs if dataset_kwargs is not None else {}, + ) + + def has_training_docs(self) -> bool: + if self.config.training_split is not None: + return True + else: + return False + + def has_validation_docs(self) -> bool: + if self.config.validation_split is not None: + return True + else: + return False + + def has_test_docs(self) -> bool: + if self.config.test_split is not None: + return True + else: + return False + + def training_docs(self) -> datasets.Dataset: + if self.has_training_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.training_split] + ) + return self.dataset[self.config.training_split] + + def validation_docs(self) -> datasets.Dataset: + if self.has_validation_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.validation_split] + ) + return self.dataset[self.config.validation_split] + + def test_docs(self) -> datasets.Dataset: + if self.has_test_docs(): + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.test_split]) + return self.dataset[self.config.test_split] + + def fewshot_docs(self): + if self.config.fewshot_split is not None: + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.fewshot_split]) + return self.dataset[self.config.fewshot_split] + elif ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("samples", None) is not None + ): + if isinstance(self.config.fewshot_config["samples"], list): + return self.config.fewshot_config["samples"] + elif callable(self.config.fewshot_config["samples"]): + return self.config.fewshot_config["samples"]() + else: + raise Exception( + "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list." + ) + else: + if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): + eval_logger.warning( + f"[Task: {self.config.task}] " + "num_fewshot > 0 but fewshot_split is None. " + "using preconfigured rule." + ) + return super().fewshot_docs() + + @staticmethod + def append_target_question( + labeled_examples: List[Dict[str, str]], + question: str, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ) -> None: + """Adds a target question to the labeled examples list. + If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry. + Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant. + """ + if not fewshot_as_multiturn: + # if no messages or last message is system, append as new user entry + if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system": + labeled_examples.append({"role": "user", "content": question}) + # if last message is user, append to it to avoid two user messages in a row + else: + labeled_examples[-1]["content"] += question + else: + # if fewshot_as_multiturn is True, append as next user entry (last is always assistant) + labeled_examples.append({"role": "user", "content": question}) + if gen_prefix: + labeled_examples.append({"role": "assistant", "content": gen_prefix}) + + @utils.positional_deprecated + def fewshot_context( + self, + doc: dict, + num_fewshot: int, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + gen_prefix: Optional[str] = None, + ) -> Union[str, List[str]]: + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param system_instruction: str + System instruction to be applied to the prompt. + :param apply_chat_template: bool + Whether to apply the chat template to the fewshot context. + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param chat_template: + callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. + :param gen_prefix: + String to append after the <|assistant|> token. + :returns: str + The fewshot context. + """ + if apply_chat_template: + labeled_examples = [] + else: + labeled_examples = "" + + # get task description + if description := self.config.description: + description = utils.apply_template(self.config.description, doc) + + # create system prompt based on the provided system instruction and description + if system_instruction is not None and description: + system_prompt = ( + f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" + ) + elif system_instruction is not None: + system_prompt = system_instruction + elif description: + system_prompt = description + else: + system_prompt = "" + + # add system prompt if specified + if system_prompt: + if apply_chat_template: + labeled_examples.append({"role": "system", "content": system_prompt}) + else: + labeled_examples = system_prompt + # if few-shot - append examples after the system prompt + if num_fewshot > 0: + if apply_chat_template: + labeled_examples.extend( + self.sampler.get_chat_context( + doc, + num_fewshot, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + ) + else: + labeled_examples += self.sampler.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ) + + example = self.doc_to_text(doc) + if apply_chat_template: + if self.multiple_input: + # TODO: append prefill? + if not labeled_examples: + return "" + return chat_template(labeled_examples) + if isinstance(example, str): + self.append_target_question( + labeled_examples, + example, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # for loglikelihood create a list of questions with appended choices + elif isinstance(example, list): + labeled_examples_list = [] + # copy chat history for each example and append the answer + for ex in example: + chat = deepcopy(labeled_examples) + self.append_target_question( + chat, + ex, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # TODO: append prefill? + labeled_examples_list.append( + chat_template( + chat, + add_generation_prompt=False if gen_prefix else True, + ) + ) + return labeled_examples_list + # if example is an integer, append the choice or convert to string + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + self.append_target_question( + labeled_examples, + choices[example], + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + else: + self.append_target_question( + labeled_examples, + str(example), + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # return lm.apply_chat_template(labeled_examples) + return chat_template( + labeled_examples, + add_generation_prompt=False if gen_prefix else True, + ) + else: + prefix = ( + self.config.target_delimiter + gen_prefix + if gen_prefix is not None + else "" + ) + if self.multiple_input: + return labeled_examples + if isinstance(example, str): + return labeled_examples + example + prefix + elif isinstance(example, list): + return [labeled_examples + ex + prefix for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + prefix + else: + return labeled_examples + str(example) + prefix + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def should_decontaminate(self): + return self.config.should_decontaminate + + def doc_to_decontamination_query(self, doc: dict): + if self.config.should_decontaminate: + if self.config.doc_to_decontamination_query is None: + return self.doc_to_text(doc) + else: + doc_to_decontamination_query = self.config.doc_to_decontamination_query + if doc_to_decontamination_query in self.features: + return doc[doc_to_decontamination_query] + elif callable(doc_to_decontamination_query): + return doc_to_decontamination_query(doc) + else: + return ast.literal_eval( + utils.apply_template( + self.config.doc_to_decontamination_query, doc + ) + ) + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. + """ + return doc + + def doc_to_text(self, doc, doc_to_text=None): + if self.prompt is not None: + doc_to_text = self.prompt + elif doc_to_text is not None: + doc_to_text = doc_to_text + else: + doc_to_text = self.config.doc_to_text + + if isinstance(doc_to_text, int): + return doc_to_text + elif isinstance(doc_to_text, str): + if doc_to_text in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_text]] + # else: + return doc[doc_to_text] + else: + text_string = utils.apply_template(doc_to_text, doc) + if text_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(text_string) + else: + return text_string + elif callable(doc_to_text): + return doc_to_text(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_text, "apply"): + applied_prompt = doc_to_text.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[0] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + print(type(doc_to_text)) + raise TypeError + + def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]: + if self.prompt is not None: + doc_to_target = self.prompt + elif doc_to_target is not None: + doc_to_target = doc_to_target + else: + doc_to_target = self.config.doc_to_target + + if isinstance(doc_to_target, int): + return doc_to_target + elif isinstance(doc_to_target, str): + if doc_to_target in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_target]] + # else: + return doc[doc_to_target] + else: + target_string = utils.apply_template(doc_to_target, doc) + if target_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(target_string) + elif ( + len(target_string) >= 2 + and (target_string[0] == "[") + and (target_string[-1] == "]") + ): + try: + return ast.literal_eval(target_string) + except (SyntaxError, ValueError): + return target_string + else: + return target_string + elif isinstance(doc_to_target, list): + return doc_to_target + elif callable(doc_to_target): + return doc_to_target(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_target, "apply"): + applied_prompt = doc_to_target.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[1] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + raise TypeError + + def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]: + if self.prompt is not None: + doc_to_choice = self.prompt + elif doc_to_choice is not None: + doc_to_choice = doc_to_choice + elif self.config.doc_to_choice is None: + eval_logger.error("doc_to_choice was called but not set in config") + else: + doc_to_choice = self.config.doc_to_choice + + if isinstance(doc_to_choice, str): + if doc_to_choice in self.features: + return doc[doc_to_choice] + else: + return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) + elif isinstance(doc_to_choice, list): + return doc_to_choice + elif isinstance(doc_to_choice, dict): + return list(doc_to_choice.values()) + elif callable(doc_to_choice): + return doc_to_choice(doc) + elif hasattr(doc_to_choice, "get_answer_choices_list"): + return doc_to_choice.get_answer_choices_list(doc) + else: + raise TypeError + + def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: + if doc_to_image is not None: + doc_to_image = doc_to_image + elif self.config.doc_to_image is not None: + doc_to_image = self.config.doc_to_image + else: + return None + + if isinstance(doc_to_image, list): + image_feature = [ + self.doc_to_image(doc, feature) for feature in doc_to_image + ] + return [feature for feature in image_feature if feature is not None] + elif isinstance(doc_to_image, str): + if doc_to_image in self.features: + return doc[doc_to_image] + else: + return ast.literal_eval(utils.apply_template(doc_to_image, doc)) + elif callable(doc_to_image): + return doc_to_image(doc) + else: + return None + + def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]: + if doc_to_audio is not None: + doc_to_audio = doc_to_audio + elif self.config.doc_to_audio is not None: + doc_to_audio = self.config.doc_to_audio + else: + return None + + if isinstance(doc_to_audio, list): + audio_feature = [ + self.doc_to_audio(doc, feature) for feature in doc_to_audio + ] + return [feature for feature in audio_feature if feature is not None] + elif isinstance(doc_to_audio, str): + if doc_to_audio in self.features: + return doc[doc_to_audio] + else: + return ast.literal_eval(utils.apply_template(doc_to_audio, doc)) + elif callable(doc_to_audio): + return doc_to_audio(doc) + else: + return None + + def doc_to_prefix(self, doc): + if (gen_prefix := self.config.gen_prefix) is not None: + if gen_prefix in self.features: + return doc[gen_prefix] + else: + return utils.apply_template(gen_prefix, doc) + return None + + def construct_requests( + self, doc: dict, ctx: str, **kwargs + ) -> Union[List[Instance], Instance]: + apply_chat_template = kwargs.pop("apply_chat_template", False) + chat_template: Callable | None = kwargs.pop("chat_template", None) + + aux_arguments = None + + if self.OUTPUT_TYPE == "loglikelihood": + arguments = (ctx, self.doc_to_target(doc)) + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + arguments = (self.doc_to_target(doc),) + elif self.OUTPUT_TYPE == "multiple_choice": + choices = self.doc_to_choice(doc) + target_delimiter = self.config.target_delimiter + if apply_chat_template: + target_delimiter = "" + if self.multiple_input: + # If there are multiple inputs, choices are placed in the ctx + # apply chat_template to choices if apply_chat_template + cont = self.doc_to_target(doc) + + arguments = [ + ( + ctx + + ( + chat_template([{"role": "user", "content": choice}]) + if apply_chat_template + else choice + ), + f"{target_delimiter}{cont}", + ) + for choice in choices + ] + else: + # Otherwise they are placed in the continuation + arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] + + # TODO: we should raise a warning telling users this will at most ~2x runtime. + if "acc_mutual_info" in self._metric_fn_list.keys(): + # if we are calculating multiple choice accuracy + # using mutual information instead of raw loglikelihood as metric, need unconditional lls. + + # here mutual info refers to calculating + # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) + # in other words normalizing by subtracting the unconditional logprob of each choice. + # TODO: should these be strided? will have to modify the processing in process_results if so + aux_arguments = [ + ("", f"{target_delimiter}{choice}") for choice in choices + ] + + arguments.extend(aux_arguments) + + elif self.OUTPUT_TYPE == "generate_until": + arguments = (ctx, deepcopy(self.config.generation_kwargs)) + + multimodal_arg = {} + if ( + self.config.doc_to_image + ): # TODO: ensure that non-multimodal tasks aren't getting visual args + multimodal_arg = { + **multimodal_arg, + **{"visual": self.doc_to_image(doc)}, + } + + if ( + self.config.doc_to_audio + ): # TODO: ensure that non-multimodal tasks aren't getting audio args + multimodal_arg = { + **multimodal_arg, + **{"audio": self.doc_to_audio(doc)}, + } + + if bool(multimodal_arg): + if isinstance(arguments, list): + arguments = [arg + (multimodal_arg,) for arg in arguments] + else: + arguments = arguments + (multimodal_arg,) + + if self.OUTPUT_TYPE == "multiple_choice": + request_list = [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=arg, + idx=i, + **kwargs, + ) + for i, arg in enumerate(arguments) + ] + + return request_list + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=arguments, + idx=0, + **kwargs, + ) + + def process_results(self, doc, results): + if callable(self.config.process_results): + return self.config.process_results(doc, results) + + result_dict = {} + use_metric = list(self._metric_fn_list.keys()) + if self.OUTPUT_TYPE == "loglikelihood": + results = results[0] + ll, is_greedy = results + return { + **({"perplexity": ll} if "perplexity" in use_metric else {}), + **({"acc": int(is_greedy)} if "acc" in use_metric else {}), + } + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + (loglikelihood,) = results + _words = self.count_words(self.doc_to_target(doc)) + _bytes = self.count_bytes(self.doc_to_target(doc)) + return { + **( + {"word_perplexity": (loglikelihood, _words)} + if "word_perplexity" in use_metric + else {} + ), + **( + {"byte_perplexity": (loglikelihood, _bytes)} + if "byte_perplexity" in use_metric + else {} + ), + **( + {"bits_per_byte": (loglikelihood, _bytes)} + if "bits_per_byte" in use_metric + else {} + ), + } + elif self.OUTPUT_TYPE == "multiple_choice": + lls, is_greedy = zip(*results) + + # retrieve choices in List[str] form, to compute choice lengths, etc. + choices = self.doc_to_choice(doc) + completion_len = np.array([float(len(i)) for i in choices]) + + if ( + 2 * len(choices) == len(lls) + and "acc_mutual_info" in self._metric_fn_list.keys() + ): + # then we are doing mutual info. + # this stores the "dryrun" / unconditional answer loglikelihoods + # as we extend the args list with unconditional ("", continuation) pairs + lls_unconditional = lls[len(choices) :] + if len(lls_unconditional) != len(choices): + raise ValueError + # and this stores our "regular" conditional loglikelihoods + lls = lls[: len(choices)] + + pred = np.argmax(lls) + pred_norm = np.argmax(lls / completion_len) + + if self.multiple_input: + gold = self.doc_to_text(doc) + else: + gold = self.doc_to_target(doc) + + gold_index_error = False + if isinstance(gold, list): + gold = [i if i < len(choices) else -100 for i in gold] + if -100 in gold: + gold_index_error = True + else: + if isinstance(gold, int): + gold = gold if gold < len(choices) else -100 + elif isinstance(gold, str): + gold = choices.index(gold) if gold in choices else -100 + + if gold == -100: + gold_index_error = True + + if gold_index_error: + eval_logger.warning( + f"Label index was not in within range of available choices," + f"Sample:\n\n{doc}\n\n" + ) + + if self.multiple_target: + acc = 1.0 if pred in gold else 0.0 + acc_norm = 1.0 if pred_norm in gold else 0.0 + exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold])) + else: + acc = 1.0 if pred == gold else 0.0 + acc_norm = 1.0 if pred_norm == gold else 0.0 + # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly + exact_match = int(is_greedy[gold]) if gold != -100 else 0 + + prob_norm = utils.softmax(lls) + + # TODO use keyword arguments to the metric? + # gold, pred, norm stuff, the original lls, + result_dict = { + **({"acc": acc} if "acc" in use_metric else {}), + **({"f1": (gold, pred)} if "f1" in use_metric else {}), + **({"mcc": (gold, pred)} if "mcc" in use_metric else {}), + **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), + **({"exact_match": exact_match} if "exact_match" in use_metric else {}), + **( + {"brier_score": (gold, prob_norm)} + if "brier_score" in use_metric + else {} + ), + } + + if "acc_mutual_info" in use_metric: + lls_mutual_info = [ + ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional) + ] + acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 + result_dict["acc_mutual_info"] = acc_mutual_info + + elif self.OUTPUT_TYPE == "generate_until": + gold = self.doc_to_target(doc) + result = results[0] + if self.config.doc_to_choice is not None: + # If you set doc_to_choice, + # it assumes that doc_to_target returns a number. + choices = self.doc_to_choice(doc) + gold = choices[gold] + # we expect multiple_targets to be a list. + elif self.multiple_target: + gold = list(gold) + # TODO: handle this better + elif type(gold) is not type(result) and not ( + "bypass" in self._metric_fn_list.keys() or isinstance(result, list) + ): + # cast gold to the same type as result + gold = type(result)(gold) + + for metric in self._metric_fn_list.keys(): + if self.multiple_target: + # in the case where we have multiple targets, + # return true if any are true + # TODO: this may break for multipLe_target, non zero-or-1 metrics + scores = [] + if not isinstance(gold, list): + # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer + # print(gold) + gold = [gold] + if metric == "exact_match": + result = [result for _ in range(len(gold))] + scores = self._metric_fn_list[metric]( + references=gold, + predictions=result, + **self._metric_fn_kwargs[metric], + )[metric] + result_score = 1.0 if scores > 0.0 else 0.0 + else: + for gold_option in gold: + try: + result_score = self._metric_fn_list[metric]( + references=[gold_option], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except ( + TypeError + ): # TODO: this is hacky and I don't want to do it + result_score = self._metric_fn_list[metric]( + [gold_option, result] + ) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + result_score = result_score[metric] + scores.append(result_score) + if any(scores): + result_score = 1.0 + else: + result_score = 0.0 + else: + try: + result_score = self._metric_fn_list[metric]( + references=[gold], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics + result_score = self._metric_fn_list[metric]([gold, result]) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + # This allows for multiple metrics to be returned from the same function + for k, v in result_score.items(): + result_dict[k] = v + else: + result_dict[metric] = result_score + else: + raise ValueError( + f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", + "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'", + ) + + return result_dict + + def aggregation(self) -> dict: + return self._aggregation_list + + def higher_is_better(self) -> dict: + return self._higher_is_better + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @property + def task_name(self) -> Any: + return getattr(self.config, "task", None) + + def __repr__(self): + return ( + f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," + f"output_type={self.OUTPUT_TYPE}," + f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," + f"num_samples={len(self.eval_docs)})" + ) + + +class MultipleChoiceTask(Task): + OUTPUT_TYPE = "loglikelihood" + + def doc_to_target(self, doc: dict) -> str: + return " " + doc["choices"][doc["gold"]] + + def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]: + # TODO: add mutual info here? + return [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=(ctx, " {}".format(choice)), + idx=i, + **kwargs, + ) + for i, choice in enumerate(doc["choices"]) + ] + + def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict: + results = [ + res[0] for res in results + ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? + gold = doc["gold"] + + acc = 1.0 if np.argmax(results) == gold else 0.0 + completion_len = np.array([float(len(i)) for i in doc["choices"]]) + acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 + + return { + "acc": acc, + "acc_norm": acc_norm, + } + + def higher_is_better(self) -> dict: + return { + "acc": True, + "acc_norm": True, + } + + def aggregation(self) -> dict: + return { + "acc": mean, + "acc_norm": mean, + } + + +class PerplexityTask(Task): + OUTPUT_TYPE = "loglikelihood_rolling" + + def has_training_docs(self) -> bool: + return False + + def fewshot_examples(self, k: int, rnd) -> List: + if k != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + return [] + + def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]: + if num_fewshot != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + + return "" + + def higher_is_better(self) -> dict: + return { + "word_perplexity": False, + "byte_perplexity": False, + "bits_per_byte": False, + } + + def doc_to_decontamination_query(self, doc): + return doc + + def doc_to_text(self, doc) -> str: + return "" + + def doc_to_target(self, doc): + return doc + + def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs): + if bool(ctx): + raise ValueError + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=(self.doc_to_target(doc),), + idx=0, + **kwargs, + ) + + def process_results(self, doc: dict, results: Tuple[float]) -> dict: + (loglikelihood,) = results + words = self.count_words(self.doc_to_target(doc)) + bytes_ = self.count_bytes(self.doc_to_target(doc)) + return { + "word_perplexity": (loglikelihood, words), + "byte_perplexity": (loglikelihood, bytes_), + "bits_per_byte": (loglikelihood, bytes_), + } + + def aggregation(self) -> dict: + return { + "word_perplexity": weighted_perplexity, + "byte_perplexity": weighted_perplexity, + "bits_per_byte": bits_per_byte, + } + + @classmethod + def count_bytes(cls, doc) -> int: + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc) -> int: + """Downstream tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/caching/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/caching/cache.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/caching/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d293b0ff8b1ebac186f5ac078cdb49227562db --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/caching/cache.py @@ -0,0 +1,59 @@ +import hashlib +import logging +import os + +import dill + + +eval_logger = logging.getLogger(__name__) + + +MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) + +OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH") + + +PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache" + +# This should be sufficient for uniqueness +HASH_INPUT = "EleutherAI-lm-evaluation-harness" + +HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() + +FILE_SUFFIX = f".{HASH_PREFIX}.pickle" + + +def load_from_cache(file_name: str, cache: bool = False): + if not cache: + return + try: + path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + with open(path, "rb") as file: + cached_task_dict = dill.loads(file.read()) + return cached_task_dict + + except Exception: + eval_logger.debug(f"{file_name} is not cached, generating...") + pass + + +def save_to_cache(file_name, obj): + if not os.path.exists(PATH): + os.mkdir(PATH) + + file_path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + eval_logger.debug(f"Saving {file_path} to cache...") + with open(file_path, "wb") as file: + file.write(dill.dumps(obj)) + + +# NOTE the "key" param is to allow for flexibility +def delete_cache(key: str = ""): + files = os.listdir(PATH) + + for file in files: + if file.startswith(key) and file.endswith(FILE_SUFFIX): + file_path = f"{PATH}/{file}" + os.unlink(file_path) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/archiver.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/archiver.py new file mode 100644 index 0000000000000000000000000000000000000000..c132232116c2ae5f5ab1dc3a2a0afc0dbd4ef1bd --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/archiver.py @@ -0,0 +1,174 @@ +import datetime +import io +import json +import mmap +import os +from pathlib import Path +from typing import Any + +import jsonlines +import tqdm +import zstandard + + +def json_serial(obj: Any) -> str: + """JSON serializer for objects not serializable by default json code""" + + if isinstance(obj, (datetime.datetime,)): + return obj.isoformat() + raise TypeError("Type %s not serializable" % type(obj)) + + +# Modified version of lm_dataformat Archive for single file. +class Archive: + def __init__(self, file_path: str, compression_level: int = 3) -> None: + self.file_path = file_path + dir_name = os.path.dirname(file_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + self.fh = open(self.file_path, "wb") + self.cctx = zstandard.ZstdCompressor(level=compression_level) + self.compressor = self.cctx.stream_writer(self.fh) + + def add_data(self, data, meta=None) -> None: + if meta is None: + meta = {} + self.compressor.write( + json.dumps({"text": data, "meta": meta}, default=json_serial).encode( + "UTF-8" + ) + + b"\n" + ) + + def commit(self) -> None: + self.compressor.flush(zstandard.FLUSH_FRAME) + self.fh.flush() + self.fh.close() + + +# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm. +class Reader: + def __init__(self) -> None: + pass + + def read( + self, + file, + get_meta: bool = False, + autojoin_paragraphs: bool = True, + para_joiner: str = "\n\n", + ): + with open(file, "rb") as fh: + self.fh = fh + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(fh)) + rdr = jsonlines.Reader(reader) + for ob in rdr: + # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility. + if isinstance(ob, str): + assert not get_meta + yield ob + continue + + text = ob["text"] + + if autojoin_paragraphs and isinstance(text, list): + text = para_joiner.join(text) + + if get_meta: + yield text, (ob["meta"] if "meta" in ob else {}) + else: + yield text + + +class TextArchive: + def __init__(self, file_path, mode: str = "rb+") -> None: + self.file_path = file_path + dir_name = os.path.dirname(file_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + + if not os.path.exists(file_path): + Path(file_path).touch() + + self.fh = open(self.file_path, mode) + + def add_data(self, data) -> None: + self.fh.write(data.encode("UTF-8") + b"\n") + + def commit(self) -> None: + self.fh.flush() + self.fh.close() + + +class TextReader: + def __init__(self, file_path) -> None: + self.file_path = file_path + + # Optimized mmap read with infrequent tqdm updates to maintain speed + # Tested up to 250MB/s. + def read_tqdm(self, update_frequency: int = 10000): + current_file_position = 0 + line_counter = 0 + with ( + open(self.file_path, "r", encoding="utf-8") as fh, + tqdm.tqdm( + total=os.path.getsize(self.file_path), + dynamic_ncols=True, + unit="byte", + unit_scale=1, + ) as progress, + ): + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + line_counter += 1 + if line_counter == update_frequency: + new_file_pos = mmap_obj.tell() + bytes_read = new_file_pos - current_file_position + current_file_position = new_file_pos + progress.update(bytes_read) + line_counter = 0 + yield line[:-1] + + def read_and_tell(self): + current_file_position = 0 + with open(self.file_path, "r", encoding="utf8") as fh: + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + new_file_pos = mmap_obj.tell() + raw_bytes_read = new_file_pos - current_file_position + current_file_position = new_file_pos + yield line[:-1], raw_bytes_read + + def read(self): + with open(self.file_path, "r", encoding="utf8") as fh: + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + yield line[:-1] + + def read_slow(self): + with open(self.file_path, "r", encoding="utf8") as fh: + while True: + line = fh.readline() + if line == -1 or line == "": + break + else: + yield line[:-1] + + +# Optimized for speed. Decompresses the archive in shell before +# using the mmap'd TextReader. +class ZStdTextReader: + def __init__(self, file) -> None: + self.file = file + + def read_tqdm(self): + decompressed_file = self.file[:-4] + print("Decompressing file, please wait...") + os.system(f"zstd -d {self.file}") # linux decompress is faster + reader = TextReader(decompressed_file) + yield from reader.read_tqdm() + os.remove(decompressed_file) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/decontaminate.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/decontaminate.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1250d39bf7cd0272e412452d970ec7c52992c5 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/decontaminate.py @@ -0,0 +1,166 @@ +import collections +import glob +import json +import os +import pickle +import random +import time + +from .archiver import ZStdTextReader +from .janitor import Janitor, word_ngrams + + +# Was used for testing the evaluator decoupled from the full logic below +def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str): + simulated_overlap = 0.1 + contaminated = int(len(docs) * simulated_overlap) + return random.sample(range(len(docs)), contaminated) + + +# Returns a dictionary containing all overlapping documents in each +# task. In the standard use case, an overlap occurs when any of the 13-grams +# found in the task document exist in the training set documents. +# +# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these +# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst" +# files. These should exist in the "ngrams_path" provided to this function. + + +# Algorithm: +# 1. Build lookups for each dataset {ngram: list(document_ids)} +# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]} +# 3. Full scan the 13-grams from the training set against the merged lookup, +# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)} +# 4. Strip the task_set from the dictionary keys and return +# +# We cache the task+set lookups as well as the overlaps. +def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict: + # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size) + + info_dict_path = os.path.join(ngrams_path, "info.json") + info_dict = json.load(open(info_dict_path, "r", encoding="utf-8")) + ngrams_n_size = info_dict["ngram_size"] + + janitor = Janitor() + + # Build lookup for each dataset first in case we use different task combinations later + print("Building Lookups...") + start = time.perf_counter() + + def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str: + return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps" + + lookups = {} + duplicates = {} # (task_name, task_set): set(doc_ids)} + sets_to_decontaminate = len(docs_by_task_set.keys()) + + for (task_name, task_set), docs in docs_by_task_set.items(): + if not os.path.exists(f"data/{task_name}"): + os.mkdir(f"data/{task_name}") + + # Check if we've decontaminated this combination before + overlaps_dump_path = get_overlaps_dump_path( + task_name, task_set, ngrams_n_size, limit + ) + if os.path.exists(overlaps_dump_path): + duplicates[(task_name, task_set)] = pickle.load( + open(overlaps_dump_path, "rb") + ) + sets_to_decontaminate -= 1 + continue + else: + duplicates[(task_name, task_set)] = set() + + # Build/load the task lookup {ngram: set(documents)}. + task_set_lookup_path = ( + f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup" + ) + if os.path.exists(task_set_lookup_path): + print(f"{task_set_lookup_path} available, loading...") + lookups[(task_name, task_set)] = pickle.load( + open(task_set_lookup_path, "rb") + ) + else: + print(f"{task_set_lookup_path} not available, building...") + lookup = collections.defaultdict(set) + + for doc_id, document in enumerate(docs): + ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size) + for ngram in ngrams: + lookup[ngram].add(doc_id) + + pickle.dump(lookup, open(task_set_lookup_path, "wb")) + lookups[(task_name, task_set)] = lookup + + elapsed = time.perf_counter() - start + print(f"Building lookups took {elapsed:0.5f} seconds.") + + matched_ngrams = [] + + if sets_to_decontaminate > 0: + print("Merging lookups...") + start = time.perf_counter() + merged_lookup = collections.defaultdict(list) + for (task_name, task_set), lookup in lookups.items(): + for ngram, doc_ids in lookup.items(): + merged_lookup[ngram].append((task_name, task_set, doc_ids)) + + elapsed = time.perf_counter() - start + print(f"Merging lookups took {elapsed:0.5f} seconds.") + + print(f"{ngrams_n_size} grams files found in {ngrams_path}:") + files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst")) + print(files) + + for file in files: + start = time.perf_counter() + print(f"Scanning {file}") + reader = ZStdTextReader(file) + total_ngrams = 0 + unique_ngrams = 0 + matching_unique = 0 + non_matching_unique = 0 + + current_ngram = "" + for line in reader.read_tqdm(): # Scan training set ngrams file + total_ngrams += 1 + [ngram, document_id] = line.rsplit(" ", 1) + if ( + ngram != current_ngram + ): # Only need to match the ngram once in training set + unique_ngrams += 1 + current_ngram = ngram + if ngram in merged_lookup: + matched_ngrams.append(ngram) # For logging + matching_unique += 1 + for task_name, task_set, doc_ids in merged_lookup[ngram]: + task_doc_set = duplicates[(task_name, task_set)] + for doc_id in doc_ids: # Record contamination across all relevant task/set combos + task_doc_set.add(doc_id) + del merged_lookup[ngram] # No point matching again + else: + non_matching_unique += 1 + + print(f"Total Ngrams: {total_ngrams}") + print(f"Unique Ngrams: {unique_ngrams}") + print(f"Unique Matching: {matching_unique}") + print(f"Unique Non Matching: {non_matching_unique}") + print("Matched ngrams:") + for ngram in matched_ngrams: + print(ngram) + + elapsed = time.perf_counter() - start + print(f"Read took {elapsed:0.5f} seconds.") + print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second") + + print(duplicates) + + # Dump overlaps separately + for (task_name, task_set), doc_ids in duplicates.items(): + overlaps_dump_path = get_overlaps_dump_path( + task_name, task_set, ngrams_n_size, limit + ) + pickle.dump(doc_ids, open(overlaps_dump_path, "wb")) + + # Strip task set and return + return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()} diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/janitor.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/janitor.py new file mode 100644 index 0000000000000000000000000000000000000000..cedf8a5717aa8156674836ba236fdcabf36e0487 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/decontamination/janitor.py @@ -0,0 +1,328 @@ +import pickle +import re +import string +import traceback +from typing import Iterator, List, Sequence, Tuple, TypeVar + + +# This is a cpp module. Compile janitor_util.cpp with: +# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup +try: + import janitor_util + + JANITOR_CPP = True +except Exception: + print("WARNING: C++ module could not be loaded. Janitor running in python mode") + traceback.print_exc() + JANITOR_CPP = False + +T = TypeVar("T") + + +# Implementation from nltk source +# https://www.nltk.org/_modules/nltk/util.html +def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]: + history = [] + while n > 1: + # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator + try: + next_item = next(sequence) + except StopIteration: + # no more data, terminate the generator + return + history.append(next_item) + n -= 1 + for item in sequence: + history.append(item) + yield tuple(history) + del history[0] + + +def word_ngrams(s: str, n: int) -> Iterator[str]: + """Splits a string into ngram words""" + tokens = s.split() # not a generator :( + ngram_seqs = form_ngrams(iter(tokens), n) + return (" ".join(ngram) for ngram in ngram_seqs) + + +# Does character sequences only - combined faster function to play around with later +# def word_ngrams_indices_combined(sequence, n): +# current_word = "" +# history = [] +# gap = False; +# start = 0 +# end = 0 +# for character in sequence: +# if character == " ": +# if not gap: +# gap = True +# history.append(current_word) +# end += len(current_word) - 1 +# current_word = "" +# if len(history) == n: +# yield (tuple(history), start, end) +# del history[0] +# start = end + 1 +# end = start +# else: +# gap = False +# current_word += character + + +# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python +def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string on whitespaces and records the indices of each in the original string. + @:return generator((word, (start_idx, end_idx)), ...) + """ + return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s)) + + +def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string into pairs of (ngram words, their start/end indices)""" + tokens_with_indices = split_indices(s) + + # Generator of ngrams of (word, idx_pairs) + # ( + # [(word, (start,end)), (word, (start, end))...], + # [(word, (start, end)), ...], + # ... + # ) + ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n) + + # Generator of pairs of word and index ngrams + # ( + # ([word, word, ...], [(start,end), (start,end), ...]), + # ... + # ) + ngram_indices_pairs = ( + zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices + ) + + # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...) + return ( + (" ".join(ngram_seq), (indices[0][0], indices[-1][1])) + for ngram_seq, indices in ngram_indices_pairs + ) + + +class Janitor: + # FIXME delete_chars: Should anything else go here? Special chars? + def __init__( + self, + ngram_n: int = 13, + window_to_remove: int = 200, + too_dirty_cutoff: int = 10, + minimum_slice_length: int = 200, + delete_chars: str = string.punctuation, + ) -> None: + self.ngram_n = ngram_n + self.window_to_remove = window_to_remove + self.too_dirty_cutoff = too_dirty_cutoff + self.minimum_slice_length = minimum_slice_length + self.delete_chars = delete_chars + + self.dirt_ngrams = set() + + # If in python, we'll translate uppercase to lowercase and delete naughty characters. + # This is fast by python standards + # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st + self.translation_table = str.maketrans( + string.ascii_lowercase + string.ascii_uppercase, # These characters + string.ascii_lowercase * 2, # Become these characters + self.delete_chars, # These are deleted + ) + + ############## + # I/O for saving contamination ngrams + ############## + + def save_contamination_ngrams(self, filename: str) -> None: + with open(filename, "wb") as fp: + pickle.dump(filename, fp) + + def load_contamination_ngrams(self, filename: str) -> None: + with open(filename, "rb") as fp: + self.dirt_ngrams = pickle.load(fp) + + ############## + # Call these :) + ############## + + def register_contaminant(self, dirt_string: str) -> None: + """Register a string as contamination to be removed, e.g. a test set + This breaks the dirt_string into ngrams to store for future cleaning""" + if JANITOR_CPP: + return self.register_contaminant_cpp(dirt_string) + else: + print("WARNING: Janitor running in python mode") + return self.register_contaminant_python(dirt_string) + + def clean(self, dirty_string: str) -> List[str]: + """Clean a string (e.g. a training set) by removing all ngrams previously + registered as contaminants. Returns a list of clean chunks, or empty if + the string was too dirty""" + if JANITOR_CPP: + return self.clean_cpp(dirty_string) + else: + print("WARNING: Janitor running in python mode") + return self.clean_python(dirty_string) + + def _split_chunks( + self, dirty_string: str, dirty_parts: Sequence[Tuple] + ) -> List[str]: + clean_chunks = [] + splice_idx = 0 + end = -1 + for i, (ngram, start, end) in enumerate(dirty_parts): + if i >= self.too_dirty_cutoff: + return [] + start = max(0, start - self.window_to_remove) + end = min(len(dirty_string), end + self.window_to_remove) + + if start - splice_idx > self.minimum_slice_length: + clean_chunks.append(dirty_string[splice_idx:start]) + splice_idx = end + + if end < len(dirty_string) - self.minimum_slice_length: + clean_chunks.append(dirty_string[end + 1 :]) + + return clean_chunks + + ############## + # Fast C++ + ############## + + def register_contaminant_cpp(self, dirt_string) -> None: + self.dirt_ngrams.update( + janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n) + ) + + def clean_cpp(self, dirty_string: str) -> List[str]: + contamination_indices = janitor_util.clean_ngram_with_indices( + dirty_string, self.delete_chars, self.ngram_n + ) + return self._split_chunks(dirty_string, contamination_indices) + + ############## + # Slow python + ############## + + def normalize_string(self, s: str) -> str: + return s.translate(self.translation_table) + + def register_contaminant_python(self, dirt_string: str) -> None: + self.dirt_ngrams.update( + word_ngrams(self.normalize_string(dirt_string), self.ngram_n) + ) + + def clean_python(self, dirty_string: str) -> List[str]: + contamination_indices = ( + (None, *idx_pair) + for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n) + if self.normalize_string(dirty_ngram) in self.dirt_ngrams + ) + return self._split_chunks(dirty_string, contamination_indices) + + +################################################################## +# Tests +################################################################# + +# def print_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 + +# for i in range(1, 10, 2): +# pprint(janitor_util.clean_ngram(source, string.punctuation, i)) +# for ngram, start, end in \ +# janitor_util.clean_ngram_with_indices(source, string.punctuation, i): +# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n")) + + +# def test_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. Clean he he" + +# jan_python = Janitor() +# jan_cpp = Janitor() + +# jan_python.register_contaminant_python(contaminant) +# jan_cpp.register_contaminant(contaminant) + +# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams) + +# assert jan_python.clean_python(source) == jan_cpp.clean(source), \ +# (jan_python.clean_python(source), jan_cpp.clean(source)) + +# print("Passed test, python==cpp") + + +# def benchmark(): +# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html +# setup = \ +# """ +# with open("data/enwik8", "r") as f: +# data = f.read() +# jan = Janitor(too_dirty_cutoff=1000) +# jan.register_contaminant(''' +# theories is that there is a connection between "geekdom" and autism. +# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled " +# The [[Geek]] Syndrome", which is a point argued by many in the autism rights +# movement{{ref|Wired}}. This article, many professionals assert, is just one example of +# the media's application of mental disease labels to what is actually variant normal behavior +# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual +# interests, even when they seem unusual to others, are not in themselves signs of autism or +# Asperger's syndrome. Others assert that it is actually the medical profession which is applying +# mental disease labels to children who in the past would have simply been accepted as a little +# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue. +# Due to the recent publicity surrounding autism and autis +# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first, +# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first +# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties +# would last, took a cautious approach, preferring to save the revenue rather than investing it in +# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential +# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his +# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]], +# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M, +# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995), +# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the +# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the +# [[United Arab Emirates]]. After the Emirates gained independence in 1971, +# ''') +# """ + +# n = 1 +# print(f"Timing {n} run on 100 MB") +# print("Register contaminant") +# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n)) + +# print("Clean") +# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n)) + + +# def test_janitor_general(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. Clean he he" + +# jan = Janitor(ngram_n=3) +# jan.register_contaminant(contaminant) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + +# filename = "data/saved_contam" +# jan.save_contamination_ngrams(filename) + +# jan = Janitor(ngram_n=3) +# jan.load_contamination_ngrams(filename) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + + +# if __name__ == "__main__": +# test() +# # print_cpp() +# # test_cpp() +# # benchmark() diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/evaluator.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2530e45d2e57dcae926e127ea7e074862ae8f9 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/evaluator.py @@ -0,0 +1,765 @@ +import itertools +import json +import logging +import random +import time +from collections import defaultdict +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +import torch + +import dllm_eval.api.metrics +import dllm_eval.api.registry +import dllm_eval.api.task +import dllm_eval.models +from dllm_eval.caching.cache import delete_cache +from dllm_eval.evaluator_utils import ( + consolidate_group_results, + consolidate_results, + get_sample_size, + get_subtask_list, + get_task_list, + prepare_print_tasks, + print_writeout, + run_task_tests, +) +from dllm_eval.loggers import EvaluationTracker +from dllm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash +from dllm_eval.tasks import TaskManager, get_task_dict +from dllm_eval.utils import ( + handle_non_serializable, + hash_string, + positional_deprecated, + setup_logging, + simple_parse_args_string, +) + + +if TYPE_CHECKING: + from dllm_eval.api.model import LM + from dllm_eval.api.task import Task + +eval_logger = logging.getLogger(__name__) + + +@positional_deprecated +def simple_evaluate( + model, + model_args: Optional[Union[str, dict]] = None, + tasks: Optional[List[Union[str, dict, object]]] = None, + num_fewshot: Optional[int] = None, + batch_size: Optional[Union[int, str]] = None, + max_batch_size: Optional[int] = None, + device: Optional[str] = None, + use_cache: Optional[str] = None, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + delete_requests_cache: bool = False, + limit: Optional[Union[int, float]] = None, + samples: Optional[dict] = None, + bootstrap_iters: int = 100000, + check_integrity: bool = False, + write_out: bool = False, + log_samples: bool = True, + evaluation_tracker: Optional[EvaluationTracker] = None, + system_instruction: Optional[str] = None, + apply_chat_template: Union[bool, str] = False, + fewshot_as_multiturn: bool = False, + gen_kwargs: Union[str, dict, None] = None, + task_manager: Optional[TaskManager] = None, + verbosity=None, + predict_only: bool = False, + random_seed: int = 0, + numpy_random_seed: int = 1234, + torch_random_seed: int = 1234, + fewshot_random_seed: int = 1234, + confirm_run_unsafe_code: bool = False, + metadata: Optional[dict] = None, +): + """Instantiate and evaluate a model on a list of tasks. + + :param model: Union[str, LM] + Name of model or LM object, see dllm_eval.models.get_model + :param model_args: Optional[str, dict] + String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object. + Ignored if `model` argument is a LM object. + :param tasks: list[Union[str, dict, Task]] + List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. + :param num_fewshot: int + Number of examples in few-shot context + :param batch_size: int or str, optional + Batch size for model + :param max_batch_size: int, optional + Maximal batch size to try with automatic batch size detection + :param device: str, optional + PyTorch device (e.g. "cpu" or "cuda:0") for running models + :param use_cache: str, optional + A path to a sqlite db file for caching model responses. `None` if not caching. + :param cache_requests: bool, optional + Speed up evaluation by caching the building of dataset requests. `None` if not caching. + :param rewrite_requests_cache: bool, optional + Rewrites all the request cache if set to `True`. `None` if not desired. + :param delete_requests_cache: bool, optional + Deletes all the request cache if set to `True`. `None` if not desired. + :param limit: int or float, optional + Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples. + :param samples: dictionary, optional + Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}. + :param bootstrap_iters: + Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed. + :param check_integrity: bool + Whether to run the relevant part of the test suite for the tasks + :param write_out: bool + If True, write out an example document and model input for checking task integrity + :param log_samples: bool + If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis + :param system_instruction: str + System instruction to be applied to the prompt + :param apply_chat_template: Union[bool, str] + Specifies whether to apply a chat template to the prompt. + - If set to True, the default chat template is applied. + - If set to a string, applies the specified chat template by name. + Defaults to False (no chat template applied). + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param gen_kwargs: dict or comma-separated string + Arguments for model generation + Ignored for all tasks with loglikelihood output_type + :param verbosity: str + Verbosity level for logging + :param predict_only: bool + If true only model outputs will be generated and returned. Metrics will not be evaluated + :param random_seed: int + Random seed for python's random module. If set to None, the seed will not be set. + :param numpy_random_seed: int + Random seed for numpy. If set to None, the seed will not be set. + :param torch_random_seed: int + Random seed for torch. If set to None, the seed will not be set. + :param fewshot_random_seed: int + Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None. + :param metadata: dict + Additional metadata to be added to the task manager. Will get passed to the download function of the task. + + return + Dictionary of results + """ + if verbosity is not None: + setup_logging(verbosity=verbosity) + start_date = time.time() + + if limit is not None and samples is not None: + raise ValueError( + "Either 'limit' or 'samples' must be None, but both are not None." + ) + + if ( + (isinstance(model_args, str) and "inst" in model_args.lower()) + or ( + isinstance(model_args, dict) + and any("inst" in str(v).lower() for v in model_args.values()) + ) + ) and not apply_chat_template: + eval_logger.warning( + "Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)." + ) + + if delete_requests_cache: + eval_logger.info("Deleting requests cache...") + delete_cache() + + seed_message = [] + if random_seed is not None: + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412 + seed_message.append(f"Setting random seed to {random_seed}") + random.seed(random_seed) + + if numpy_random_seed is not None: + seed_message.append(f"Setting numpy seed to {numpy_random_seed}") + np.random.seed(numpy_random_seed) + + if torch_random_seed is not None: + seed_message.append(f"Setting torch manual seed to {torch_random_seed}") + torch.manual_seed(torch_random_seed) + + if fewshot_random_seed is not None: + seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}") + + if seed_message: + eval_logger.info(" | ".join(seed_message)) + + if tasks is None: + tasks = [] + if len(tasks) == 0: + raise ValueError( + "No tasks specified, or no tasks found. Please verify the task names." + ) + + if gen_kwargs is not None: + if isinstance(gen_kwargs, str): + gen_kwargs = simple_parse_args_string(gen_kwargs) + eval_logger.warning( + f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. " + "Ensure 'do_sample=True' for non-greedy decoding!" + ) + if not gen_kwargs: + gen_kwargs = None + + if isinstance(model, str): + if model_args is None: + eval_logger.warning("model_args not specified. Using defaults.") + model_args = "" + + if isinstance(model_args, dict): + eval_logger.info( + f"Initializing {model} model, with arguments: {model_args}" + ) + lm = dllm_eval.api.registry.get_model(model).create_from_arg_obj( + model_args, + { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + }, + ) + + else: + eval_logger.info( + f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}" + ) + lm = dllm_eval.api.registry.get_model(model).create_from_arg_string( + model_args, + { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + }, + ) + else: + if not isinstance(model, dllm_eval.api.model.LM): + raise TypeError( + f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of dllm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `dllm_eval.models.huggingface.HFLM(pretrained=my_model)` first." + ) + eval_logger.info("Using pre-initialized model") + lm = model + + if use_cache is not None: + eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}") + lm = dllm_eval.api.model.CachingLM( + lm, + use_cache + # each rank receives a different cache db. + # necessary to avoid multiple writes to cache at once + + "_rank" + + str(lm.rank) + + ".db", + ) + + if task_manager is None: + metadata = ( + simple_parse_args_string(model_args) + if isinstance(model_args, str) + else model_args + if isinstance(model_args, dict) + else {} + ) | (metadata or {}) + task_manager = TaskManager(metadata=metadata) + + task_dict = get_task_dict( + tasks, + task_manager, + ) + + # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups. + # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed) + def _adjust_config(task_dict): + adjusted_task_dict = {} + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + adjusted_task_dict = { + **adjusted_task_dict, + **{task_name: _adjust_config(task_obj)}, + } + + else: + if task_obj.get_config("output_type") == "generate_until": + if gen_kwargs is not None: + task_obj.set_config( + key="generation_kwargs", value=gen_kwargs, update=True + ) + eval_logger.info( + f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}" + ) + + if predict_only: + eval_logger.info( + f"Processing {task_name} in output-only mode. Metrics will not be calculated!" + ) + # we have to change the class properties post-hoc. This is pretty hacky. + task_obj.override_metric(metric_name="bypass") + + # override tasks' fewshot values to the provided num_fewshot arg value + # except if tasks have it set to 0 manually in their configs--then we should never overwrite that + if num_fewshot is not None: + if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0: + eval_logger.info( + f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored." + ) + else: + eval_logger.warning( + f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" + ) + task_obj.set_config(key="num_fewshot", value=num_fewshot) + else: + # if num_fewshot not provided, and the task does not define a default one, default to 0 + if ( + default_num_fewshot := task_obj.get_config("num_fewshot") + ) is None: + task_obj.set_config(key="num_fewshot", value=0) + # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) + task_obj.set_fewshot_seed(seed=fewshot_random_seed) + + adjusted_task_dict[task_name] = task_obj + + return adjusted_task_dict + + task_dict = _adjust_config(task_dict) + + if check_integrity: + run_task_tests(task_list=tasks) + + if evaluation_tracker is not None: + evaluation_tracker.general_config_tracker.log_experiment_args( + model_source=model, + model_args=model_args, + system_instruction=system_instruction, + chat_template=lm.chat_template(apply_chat_template) + if apply_chat_template + else None, + fewshot_as_multiturn=fewshot_as_multiturn, + ) + + results = evaluate( + lm=lm, + task_dict=task_dict, + limit=limit, + samples=samples, + cache_requests=cache_requests, + rewrite_requests_cache=rewrite_requests_cache, + bootstrap_iters=bootstrap_iters, + write_out=write_out, + log_samples=True if predict_only else log_samples, + system_instruction=system_instruction, + apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, + verbosity=verbosity, + confirm_run_unsafe_code=confirm_run_unsafe_code, + ) + if verbosity is not None: + setup_logging(verbosity=verbosity) + + if lm.rank == 0: + if isinstance(model, str): + model_name = model + elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"): + model_name = model.config._name_or_path + else: + model_name = type(model).__name__ + + # add info about the model and few shot config + results["config"] = { + "model": model_name, + "model_args": model_args, + } + # add more detailed model info if available + if isinstance(lm, dllm_eval.models.huggingface.HFLM): + results["config"].update(lm.get_model_info()) + # add info about execution + results["config"].update( + { + "batch_size": batch_size, + "batch_sizes": ( + list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else [] + ), + "device": device, + "use_cache": use_cache, + "limit": limit, + "bootstrap_iters": bootstrap_iters, + "gen_kwargs": gen_kwargs, + "random_seed": random_seed, + "numpy_seed": numpy_random_seed, + "torch_seed": torch_random_seed, + "fewshot_seed": fewshot_random_seed, + } + ) + results["git_hash"] = get_git_commit_hash() + results["date"] = start_date + add_env_info(results) # additional environment info to results + add_tokenizer_info(results, lm) # additional info about tokenizer + return results + else: + return None + + +@positional_deprecated +def evaluate( + lm: "LM", + task_dict, + limit: Optional[int] = None, + samples: Optional[dict] = None, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + bootstrap_iters: Optional[int] = 100000, + write_out: bool = False, + log_samples: bool = True, + system_instruction: Optional[str] = None, + apply_chat_template: Union[bool, str] = False, + fewshot_as_multiturn: bool = False, + verbosity: str = "INFO", + confirm_run_unsafe_code: bool = False, +): + """Instantiate and evaluate a model on a list of tasks. + + :param lm: obj + Language Model + :param task_dict: dict[str, Task] + Dictionary of tasks. Tasks will be taken to have name type(task).config.task . + :param limit: int, optional + Limit the number of examples per task (only use this for testing) + :param samples: dictionary, optional + Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}. + :param cache_requests: bool, optional + Speed up evaluation by caching the building of dataset requests. + :param rewrite_requests_cache: bool, optional + Rewrites all the request cache if set to `True`. + :param bootstrap_iters: + Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations. + :param write_out: bool + If True, write out an example document and model input for checking task integrity + :param log_samples: bool + If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis + :param system_instruction: str + System instruction to be applied to the prompt + :param apply_chat_template: Union[bool, str] + Specifies whether to apply a chat template to the prompt. + - If set to True, the default chat template is applied. + - If set to a string, applies the specified chat template by name. + Defaults to False (no chat template applied). + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param verbosity: str + Verbosity level for logging + :param confirm_run_unsafe_code: bool + Whether to confirm running tasks marked as unsafe. + :return + Dictionary of results + """ + + if limit is not None and samples is not None: + raise ValueError( + "Either 'limit' or 'samples' must be None, but both are not None." + ) + if samples is not None: + eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}") + if apply_chat_template: + eval_logger.warning( + "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details." + ) + # tracks all Instances/requests a model must generate output on. + requests = defaultdict(list) + # stores the amount to pad out reqs per req. type so that + # number of fwd passes per distributed rank is equal + padding_requests = defaultdict(int) + + # get lists of group hierarchy and each type of request + eval_tasks = get_task_list(task_dict) + if not log_samples: + if not all( + "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() + for task_output in eval_tasks + ): + raise ValueError("log_samples must be True for 'bypass' metric-only tasks") + + # validation checks: + # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa. + # 2.are we running code that is marked as unsafe. + incompatible_tasks = [] + for task_output in eval_tasks: + task: Task = task_output.task + + if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False): + incompatible_tasks.append(task_output.task_name) + elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code: + raise ValueError( + f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task." + ) + if len(incompatible_tasks) > 0: + if not getattr(lm, "MULTIMODAL", False): + raise ValueError( + f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type." + ) + # end validation check + + # Cache the limit arg. + limit_arg = limit + limits = [] + for task_output in eval_tasks: + task: Task = task_output.task + + limit = get_sample_size(task, limit_arg) + limits.append(limit) + task.build_all_requests( + limit=limit, + samples=samples.get(task_output.task_name, None) + if samples is not None + else samples, + rank=lm.rank, + world_size=lm.world_size, + cache_requests=cache_requests, + rewrite_requests_cache=rewrite_requests_cache, + system_instruction=system_instruction, + apply_chat_template=bool(apply_chat_template), + fewshot_as_multiturn=fewshot_as_multiturn, + chat_template=getattr(lm, "apply_chat_template") + if apply_chat_template + else None, + tokenizer_name=getattr(lm, "tokenizer_name", "") + if apply_chat_template + else "", + ) + eval_logger.debug( + f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" + ) + if write_out: + print_writeout(task) + # aggregate Instances by LM method requested to get output. + for instance in task.instances: + reqtype = instance.request_type + requests[reqtype].append(instance) + + if lm.world_size > 1: + instances_rnk = torch.tensor(len(task._instances), device=lm.device) + gathered_item = ( + lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() + ) + # "multiple_choice" task types dispatch (several) "loglikelihood" request types + reqtype = ( + "loglikelihood" + if task.OUTPUT_TYPE == "multiple_choice" + else task.OUTPUT_TYPE + ) + # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks) + numpad = max(gathered_item) - gathered_item[lm.rank] + # todo: may not account for padding in cases like SquadV2 which has multiple req types + padding_requests[reqtype] += numpad + + ### Run LM on inputs, get all outputs ### + # execute each type of request + for reqtype, reqs in requests.items(): + eval_logger.info(f"Running {reqtype} requests") + # create `K` copies of each request `req` based off `K = req.repeats` + cloned_reqs = [] + for req in reqs: + cloned_reqs.extend([req] * req.repeats) + + if (lm.world_size > 1) and (padding_requests[reqtype] > 0): + for _ in range(padding_requests[reqtype]): + cloned_reqs.extend([req] * req.repeats) + + # run requests through model + resps = getattr(lm, reqtype)(cloned_reqs) + + # put responses from model into a list of length K for each request. + for x, req in zip(resps, cloned_reqs): + req.resps.append(x) + + if lm.world_size > 1: + lm.accelerator.wait_for_everyone() + + RANK = lm.rank + WORLD_SIZE = lm.world_size + ### Postprocess outputs ### + # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately) + for task_output, limit in zip(eval_tasks, limits): + task = task_output.task + task.apply_filters() + + ### Collect values of metrics on all datapoints ### + # # unpack results and sort back in order and return control to Task + # TODO: make it possible to use a different metric per filter + # Pre-process task.instances to group by doc_id + instances_by_doc_id = defaultdict(list) + for instance in task.instances: + instances_by_doc_id[instance.doc_id].append(instance) + # Sort instances within each group + for instances in instances_by_doc_id.values(): + instances.sort(key=lambda x: x.idx) + # iterate over different filters used + for filter_key in task.instances[0].filtered_resps.keys(): + indices = ( + samples.get(task_output.task_name, None) + if samples is not None + else None + ) + doc_iterator = task.doc_iterator( + rank=RANK, + limit=limit, + world_size=WORLD_SIZE, + samples=indices, + ) + for doc_id, doc in doc_iterator: + if indices: + doc_id_true = indices[doc_id] + else: + doc_id_true = doc_id + requests = instances_by_doc_id[doc_id] + metrics = task.process_results( + doc, [req.filtered_resps[filter_key] for req in requests] + ) + if log_samples: + target = task.doc_to_target(doc) + example = { + "doc_id": doc_id_true, + "doc": doc, + "target": target, + "arguments": [req.args for req in requests], + "resps": [req.resps for req in requests], + "filtered_resps": [ + req.filtered_resps[filter_key] for req in requests + ], + "filter": filter_key, + "metrics": list(metrics.keys()), + "doc_hash": hash_string( + json.dumps( + requests[0].doc, + indent=2, + default=handle_non_serializable, + ensure_ascii=False, + ) + ), + "prompt_hash": hash_string(requests[0].arguments[0]), + "target_hash": hash_string(str(target)), + } + example.update(metrics) + task_output.logged_samples.append(example) + for metric, value in metrics.items(): + task_output.sample_metrics[(metric, filter_key)].append(value) + + if WORLD_SIZE > 1: + # if multigpu, then gather data across all ranks to rank 0 + # first gather logged samples across all ranks + for task_output in eval_tasks: + if log_samples: + # for task_name, task_samples in list(samples.items()): + full_samples = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.logged_samples, + object_gather_list=full_samples, + dst=0, + ) + + if RANK == 0: + task_output.logged_samples = list( + itertools.chain.from_iterable(full_samples) + ) + + # then collect metrics across all ranks + for metrics in task_output.sample_metrics: + metric_list = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.sample_metrics[metrics], + object_gather_list=metric_list, + dst=0, + ) + if RANK == 0: + task_output.sample_metrics[metrics] = list( + itertools.chain.from_iterable(metric_list) + ) + + if RANK == 0: + ### Aggregate results over all datapoints ### + # aggregate results ; run bootstrap CIs + for task_output in eval_tasks: + task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters) + ( + results, + samples, + configs, + versions, + num_fewshot, + higher_is_better, + ) = consolidate_results(eval_tasks) + + ### Calculate group metrics ### + if bool(results): + results, versions, show_group_table, *_ = consolidate_group_results( + results, versions, task_dict + ) + + results_agg, group_agg = prepare_print_tasks(task_dict, results) + subtask_list = get_subtask_list(task_dict) + + # collect all higher_is_better values for metrics + # in the group's subtasks. + # TODO: clean this up ; unify with the below metric_list loop? + _higher_is_better = {} + for group, task_list in subtask_list.items(): + if ( + len(task_list) != 0 + ): # subtask list will list "task_name": [] for solo tasks + for task in task_list: + for m, h in higher_is_better[task].items(): + if m not in _higher_is_better.keys(): + _higher_is_better[m] = h + + if ( + m in _higher_is_better + and _higher_is_better[m] is not None + and _higher_is_better[m] != h + ): + eval_logger.warning( + f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None." + ) + _higher_is_better[m] = None + higher_is_better[group] = _higher_is_better + + results_dict = { + "results": dict(results_agg.items()), + **( + {"groups": dict(group_agg.items())} + if (bool(group_agg) & show_group_table) + else {} + ), + "group_subtasks": dict(reversed(subtask_list.items())), + "configs": dict(sorted(configs.items())), + "versions": dict(sorted(versions.items())), + "n-shot": dict(sorted(num_fewshot.items())), + "higher_is_better": dict(sorted(higher_is_better.items())), + "n-samples": { + task_output.task_name: { + "original": len(task_output.task.eval_docs), + "effective": min( + limit if limit else len(task_output.task.eval_docs), + len(task_output.task.eval_docs), + ), + } + for task_output, limit in zip(eval_tasks, limits) + }, + } + if log_samples: + results_dict["samples"] = dict(samples) + + return results_dict + + else: + return None + + +def request_caching_arg_to_dict(cache_requests: str) -> dict: + request_caching_args = { + "cache_requests": cache_requests in {"true", "refresh"}, + "rewrite_requests_cache": cache_requests == "refresh", + "delete_requests_cache": cache_requests == "delete", + } + + return request_caching_args diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/evaluator_utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/evaluator_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a17950fe606559fb1c3c72fb3e8404759788bbe --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/evaluator_utils.py @@ -0,0 +1,554 @@ +import collections +import logging +import math +import pathlib +import sys +from typing import List, Optional, Tuple, Union + +from dllm_eval.api.group import ConfigurableGroup +from dllm_eval.api.metrics import ( + aggregate_subtask_metrics, + mean, + pooled_sample_stderr, + stderr_for_metric, +) +from dllm_eval.api.task import Task +from dllm_eval.utils import positional_deprecated + + +eval_logger = logging.getLogger(__name__) + + +class TaskOutput: + """ + Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task. + + Attributes: + task (object): The task object. + task_name (str): The name of the task. + task_config (dict): The configuration of the task. + version (str): The version of the task. + group_name (str): The name of the task group. + n_shot (int): The number of shots for the task. + task_alias (str): The alias of the task. + group_alias (str): The alias of the task group. + is_group (bool): Indicates if the task is a group. + logged_samples (list): The list of logged samples. + sample_len (int): The length of the samples. + sample_metrics (defaultdict): The dictionary of samples' metrics. + agg_metrics (defaultdict): The dictionary of aggregate metrics. + + Methods: + from_taskdict(cls, task_name: str, task): + Creates a TaskOutput instance from a task dictionary. + + calculate_aggregate_metric(bootstrap_iters=100000) -> None: + Calculates the aggregate metrics for the task. + """ + + def __init__( + self, + task=None, + task_name=None, + task_config=None, + version=None, + group_name=None, + n_shot=None, + task_alias=None, + group_alias=None, + is_group=None, + ): + self.task = task + self.task_config = task_config + self.task_name = task_name + self.group_name = group_name + self.version = version + self.n_shot = n_shot + self.task_alias = task_alias + self.group_alias = group_alias + self.is_group = is_group + self.logged_samples = [] + self.sample_len = None + self.sample_metrics = collections.defaultdict(list) + self.agg_metrics = collections.defaultdict(list) + + @classmethod + def from_taskdict(cls, task_name: str, task): + if isinstance(task, tuple): + group_name, task = task + else: + group_name = None + if not task: + # these gets filtered out in get_task_list + # once they are added to group hierarchy + is_group = True + return cls( + task=task, task_name=task_name, is_group=is_group, group_name=group_name + ) + version = task.VERSION + task_config = dict(task.dump_config()) + if (n_shot := task_config.get("num_fewshot")) == 0: + n_shot = task_config.get("metadata", {}).get("num_fewshot", 0) + task_alias = task_config.get("alias") + group_alias = task_config.get("group_alias") + return cls( + task=task, + task_name=task_name, + task_config=task_config, + group_name=group_name, + version=version, + n_shot=n_shot, + task_alias=task_alias, + group_alias=group_alias, + ) + + def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None: + for (metric, filter_key), items in self.sample_metrics.items(): + try: + agg_fn = self.task.aggregation()[metric] + except KeyError: + # This is when process results output an arbitrary metric + # TODO: Handle this better and allow other aggregate functions other than mean. + agg_fn = mean + metric_key = f"{metric},{filter_key}" + self.agg_metrics[metric_key] = agg_fn(items) + self.sample_len = len(items) # TODO: same sample size for each metric? + if isinstance(bootstrap_iters, int): + stderr_fn = stderr_for_metric( + metric=agg_fn, + bootstrap_iters=min(bootstrap_iters, 100) + if metric in ["bleu", "chrf", "ter"] + else bootstrap_iters, + ) + self.agg_metrics[f"{metric}_stderr,{filter_key}"] = ( + stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A" + ) + else: + raise ValueError( + f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations." + ) + + def __repr__(self): + return ( + f"TaskOutput(task_name={self.task_name}, " + f"group_name={self.group_name}, " + f"version={self.version}, " + f"n_shot={self.n_shot}, " + f"task_alias={self.task_alias}, " + f"group_alias={self.group_alias})" + ) + + +def get_task_list(task_dict: dict) -> List[TaskOutput]: + outputs = [] + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + _outputs = get_task_list(task_obj) + outputs.extend(_outputs) + else: + task_output = TaskOutput.from_taskdict(task_name, task_obj) + outputs.append(task_output) + + return outputs + + +def get_subtask_list(task_dict, task_root=None, depth=0): + subtask_list = {} + for group_obj, task_obj in task_dict.items(): + if isinstance(group_obj, ConfigurableGroup): + # group_name = group_obj.group_name + group_name = group_obj.group_name + else: + group_name = group_obj + if isinstance(task_obj, dict): + _subtask_list = get_subtask_list( + task_obj, task_root=group_name, depth=depth + 1 + ) + if task_root: + subtask_list.setdefault((task_root, depth), []).extend( + [ + _task + for (_task, _depth) in _subtask_list.keys() + if (_depth - 1) == depth + ] + ) + + subtask_list = {**subtask_list, **_subtask_list} + else: + if isinstance(task_obj, ConfigurableGroup): + # group_or_task_name = task_obj.group_name + group_or_task_name = task_obj.group_name + elif isinstance(task_obj, Task): + # group_or_task_name = task_obj.task_name + group_or_task_name = task_obj.task_name + + if task_root is None: + subtask_list.setdefault((group_or_task_name, depth), []) + else: + subtask_list.setdefault((task_root, depth), []).append( + group_or_task_name + ) + + if depth == 0: + _subtask_list = {} + for group_key, task_list in subtask_list.items(): + group_name, depth = group_key + _subtask_list[group_name] = task_list + subtask_list = _subtask_list + + return subtask_list + + +def print_writeout(task) -> None: + for inst in task.instances: + # print the prompt for the first few documents + if inst.doc_id < 1: + eval_logger.info( + f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\ + \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)" + ) + eval_logger.info(f"Request: {str(inst)}") + + +def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: + if limit is not None: + limit = ( + int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit) + ) + return limit + + +def prepare_print_tasks( + task_dict: dict, + results: dict, + task_depth=0, + group_depth=0, +) -> Tuple[dict, dict]: + """ + @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its + value is a list of task names. + @param results: Dictionary containing the results of each task. Each key is a + group name and its value is a dictionary of task results. + @param task_depth: The indentation level for printing the task + hierarchy. Default is 0. + @param group_depth: The indentation level for printing the group + hierarchy. Default is 0. + @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains + aggregated results for each task, and groups_agg contains aggregated results for each group. + + Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. + """ + + def _sort_task_dict(task_dict): + """ + Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name. + Required so that we end up sorting within each sub-header correctly. + """ + + return dict( + sorted( + task_dict.items(), + key=lambda item: item[0].group_name + if isinstance(item[0], ConfigurableGroup) + else item[0], + ) + ) + + task_agg = collections.defaultdict(dict) + group_agg = collections.defaultdict(dict) + task_dict = _sort_task_dict(task_dict) + for task_or_group_name, task_or_group_obj in task_dict.items(): + tab_string = " " * task_depth + "- " if task_depth > 0 else "" + if isinstance(task_or_group_name, ConfigurableGroup): + # string_name = task_or_group_name.group_name + name = task_or_group_name.group_name + from_configurable_group = True + task_or_group_obj = _sort_task_dict(task_or_group_obj) + elif isinstance(task_or_group_name, str): + name = task_or_group_name + if isinstance(task_or_group_obj, Task): + # string_name = task_or_group_obj.task_name + name = task_or_group_obj.task_name + from_configurable_group = False + + task_agg[name] = results[name].copy() + if from_configurable_group: + if task_or_group_name.group_alias is not None: + alias = task_or_group_name.group_alias + else: + alias = task_or_group_name.group + else: + if "alias" in task_agg[name]: + alias = task_agg[name]["alias"] + else: + alias = name + + task_agg[name]["alias"] = tab_string + alias + if "samples" in task_agg[name]: + task_agg[name].pop("samples") + + if from_configurable_group and (" " not in results[name]): + group_tab_string = " " * group_depth + "- " if group_depth > 0 else "" + group_agg[name] = results[name].copy() + group_agg[name]["alias"] = group_tab_string + alias + if "samples" in group_agg[name]: + group_agg[name].pop("samples") + + if isinstance(task_or_group_obj, dict): + task_depth += 1 + group_depth += 1 + _task_agg, _group_agg = prepare_print_tasks( + task_or_group_obj, results, task_depth, group_depth + ) + task_agg = { + **task_agg, + **_task_agg, + } + group_agg = {**group_agg, **_group_agg} + task_depth -= 1 + group_depth -= 1 + return task_agg, group_agg + + +def consolidate_results( + eval_tasks: List[TaskOutput], +) -> Tuple[dict, dict, dict, dict, dict, dict]: + """ + @param eval_tasks: list(TaskOutput). + @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot. + + Consolidates the results of multiple evaluation tasks into a single structure. + + The method iterates over each evaluation instance and extracts relevant information to create the consolidated + results structure. The consolidated results structure has the following properties: + + - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains + metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task + aliases specified in the task configuration. + - samples: A defaultdict with task names as keys and lists of log samples as values. + - configs: A defaultdict with task names as keys and task configurations as values. + - versions: A defaultdict with task names as keys and task versions as values. + - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values. + - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better + for each metric as values. + + The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple. + """ + # stores the final result for each task, for each metric/filter pair. + results = collections.defaultdict(dict) + # logs info about each document evaluated. + samples = collections.defaultdict(list) + # store num-fewshot value per task + num_fewshot = collections.defaultdict(int) + # Tracks the YAML configs of all chosen task + configs = collections.defaultdict(dict) + # Tracks each task's version. + versions = collections.defaultdict(dict) + # Track `higher_is_better` for each metric + higher_is_better = collections.defaultdict(dict) + + for task_output in eval_tasks: + if "task_alias" in (task_config := task_output.task_config): + results[task_output.task_name]["alias"] = task_config["task_alias"] + else: + results[task_output.task_name]["alias"] = task_output.task_name + if group_alias := task_output.group_alias: + if group_alias not in results and (group_name := task_output.group_name): + results[group_name]["alias"] = group_alias + num_fewshot[task_output.task_name] = task_output.n_shot + configs[task_output.task_name] = task_output.task_config + versions[task_output.task_name] = task_output.version + samples[task_output.task_name] = task_output.logged_samples + higher_is_better[task_output.task_name] = task_output.task.higher_is_better() + for (metric, filter_key), items in task_output.sample_metrics.items(): + metric_key = f"{metric},{filter_key}" + results[task_output.task_name][metric_key] = task_output.agg_metrics[ + metric_key + ] + results[task_output.task_name]["samples"] = task_output.sample_len + results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = ( + task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] + ) + return results, samples, configs, versions, num_fewshot, higher_is_better + + +def consolidate_group_results( + results, + versions, + task_dict, + task_root=None, + show_group_table=False, + task_aggregation_list=None, +) -> Tuple[dict, dict, bool, Union[None,]]: + """ + (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info. + + @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below: + + - results: A defaultdict with task names (and, after this function is called, group names of + groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys. + - versions: A defaultdict with task names (and, after this function is called, group names of + groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None). + - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table. + - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric. + + The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple. + In the top-level invocation of this function, task_aggregation_list is ignored. + """ + if task_root is None: + task_root = {} + + if task_aggregation_list is None: + task_aggregation_list = {} + + for group_or_task, group_or_task_info in task_dict.items(): + # Convert to string + if isinstance(group_or_task, ConfigurableGroup): + group_config = group_or_task.config + group_or_task = group_or_task.group_name + else: + group_config = None + + if isinstance(group_or_task_info, Task): + if task_root: + task_aggregation_list.setdefault(task_root, []).append( + group_or_task_info.task_name + ) + else: + ( + results, + versions, + show_group_table, + _task_aggregation_list, + ) = consolidate_group_results( + results, + versions, + group_or_task_info, + group_or_task, + show_group_table, + task_aggregation_list, + ) + if task_root: + task_aggregation_list.setdefault(task_root, []).extend( + task_aggregation_list.get(group_or_task, []) + ) + + if (group_config is None) or ( + group_config["aggregate_metric_list"] is None + ): + results[group_or_task][" "] = " " + continue + + if "aggregate_metric_list" in group_config: + agg_metric_list = group_config["aggregate_metric_list"] + + show_group_table = show_group_table | bool( + group_config["aggregate_metric_list"] + ) + + task_list = _task_aggregation_list[group_or_task] + + metric_list = list( + { + key + for task in task_list + for key in results[task].keys() + if "_stderr" not in key and key not in ["task", "alias", "samples"] + } + ) + for metric in metric_list: + stderr = "_stderr,".join(metric.split(",")) + + # gather metrics, sizes, and stderrs from subtasks + metrics = [ + results[task][metric] + for task in task_list + if metric in results[task] + ] # TODO: copy? + stderrs = [ + results[task][stderr] + for task in task_list + if stderr in results[task] + ] + sizes = [ + results[task]["samples"] + for task in task_list + if metric in results[task] + ] + + for metric_config in agg_metric_list: + for filter_name in metric_config["filter_list"]: + if metric != ",".join([metric_config["metric"], filter_name]): + continue + + # compute group's pooled metric and stderr + if metric_config["aggregation"] == "mean": + aggregate_fn = aggregate_subtask_metrics + elif callable(metric_config["aggregation"]): + aggregate_fn = metric_config["aggregation"] + else: + raise ValueError( + f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'" + ) + + results[group_or_task][metric] = aggregate_fn( + metrics, + sizes, + metric_config["weight_by_size"], + ) + # TODO: calculate groups' metrics using arbitrary agg fns + if "N/A" in stderrs: + results[group_or_task][stderr] = "N/A" + else: + # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere + results[group_or_task][stderr] = pooled_sample_stderr( + stderrs, sizes + ) + + results[group_or_task]["samples"] = sum(sizes) + group_metadata = group_config.get("metadata", None) + if group_metadata is not None: + versions[group_or_task] = group_metadata.get("version", None) + # print(results) + return results, versions, show_group_table, task_aggregation_list + + +@positional_deprecated +def find_test_root(start_path: pathlib.Path) -> pathlib.Path: + """ + Search upward in the directory tree to a maximum of three layers + to find and return the package root (containing the 'tests' folder) + """ + cur_path = start_path.resolve() + max_layers = 3 + for _ in range(max_layers): + if (cur_path / "tests" / "test_version_stable.py").exists(): + return cur_path + else: + cur_path = cur_path.parent.resolve() + raise FileNotFoundError( + f"Unable to find package root within {max_layers} upwards" + f"of {start_path}" + ) + + +@positional_deprecated +def run_task_tests(task_list: List[str]): + """ + Find the package root and run the tests for the given tasks + """ + import pytest + + package_root = find_test_root(start_path=pathlib.Path(__file__)) + task_string = " or ".join(task_list) + args = [ + f"{package_root}/tests/test_version_stable.py", + f"--rootdir={package_root}", + "-k", + f"{task_string}", + ] + sys.path.append(str(package_root)) + pytest_return_val = pytest.main(args) + if pytest_return_val: + raise ValueError( + f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}" + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8911d26c34cc07d1c92d20b904f48ef6fcce8ea4 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/__init__.py @@ -0,0 +1,25 @@ +from functools import partial +from typing import List + +from dllm_eval.api.filter import FilterEnsemble +from dllm_eval.api.registry import get_filter + +from . import custom, extraction, selection, transformation + + +def build_filter_ensemble( + filter_name: str, components: List[List[str]] +) -> FilterEnsemble: + """ + Create a filtering pipeline. + """ + filters = [] + for function, kwargs in components: + if kwargs is None: + kwargs = {} + # create a filter given its name in the registry + f = partial(get_filter(function), **kwargs) + # add the filter as a pipeline step + filters.append(f) + + return FilterEnsemble(name=filter_name, filters=filters) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/custom.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..07576f8a503f816de42ca1a80729edb517d75a5c --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/custom.py @@ -0,0 +1,17 @@ +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("custom") +class CustomFilter(Filter): + """ + Custom filter that applies a custom, user-defined function to the model responses. + """ + + def __init__(self, **kwargs) -> None: + self.filter_fn = kwargs.pop("filter_fn") + + super().__init__(**kwargs) + + def apply(self, resps, docs): + return self.filter_fn(resps, docs) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/decontamination.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/decontamination.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4ff15a2d856a0a191aaeb5288c3706275dddd8 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/decontamination.py @@ -0,0 +1,25 @@ +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("decontaminate") +class DecontaminationFilter(Filter): + """ + A filter which evaluates + """ + + name = "track_decontamination" + + def __init__(self, path) -> None: + """ + + TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). + should further cache result on a given (task_name, doc_id) + """ + self._decontam_results = None + + def apply(self, resps, docs) -> None: + """ + Return {"no_contamination", "only_contamination"} keys for the 2 different subsets + """ + pass diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/extraction.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..3998e7c463e5f75cff6ed19c135441cc40ba3c8b --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/extraction.py @@ -0,0 +1,233 @@ +import re +import sys +import unicodedata + +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("regex") +class RegexFilter(Filter): + """A filter that extracts values from text using regex pattern matching. + + This filter applies a regex pattern to each model response and extracts matched values. + If no match is found, returns a fallback value. Useful for extracting structured data + (like numbers) from unstructured model outputs. + """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select: int = 0, + fallback: str = "[invalid]", + ) -> None: + """ + pass a string `regex` to run `re.compile(r"regex")` on. + `fallback` defines the output returned if no matches for the regex are located. + """ + self.regex_pattern = regex_pattern + self.regex = re.compile(regex_pattern) + self.group_select = group_select + self.fallback = fallback + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + def filter_set(inst): + filtered = [] + for resp in inst: + match = self.regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m] + if match: + match = match[0] + else: + match = self.fallback + match = match.strip() + else: + match = self.fallback + filtered.append(match) + return filtered + + filtered_resps = list(map(lambda x: filter_set(x), resps)) + return filtered_resps + + +@register_filter("regex_pos") +class POSFilter(Filter): + """ """ + + def __init__( + self, + regex_pattern: str = r"\['(.*?)'\]", + group_select=0, + fallback=None, + ) -> None: + """ + pass a string `regex` to run `re.compile(r"regex")` on. + `fallback` defines the output returned if no matches for the regex are located. + """ + if fallback is None: + fallback = ["invalid"] + self.regex_pattern = regex_pattern + self.regex = re.compile(regex_pattern) + self.group_select = group_select + self.fallback = fallback + + def apply(self, resps, docs): + def extract_tagged_tokens(text): + # Extract tagged tokens list from text input using regex + tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text) + return [(token, pos) for token, pos in tokens] + + def extract_pos_tags(result): + pos_tags = [] + if isinstance(result, str): + result = extract_tagged_tokens(result) + pos_tags.extend(pos for _, pos in result) + return pos_tags if pos_tags else self.fallback + + def filter_set(inst): + filtered = [] + for resp in inst: + match = extract_pos_tags(resp) + filtered.append(match) + return filtered + + filtered_resps = map(lambda x: filter_set(x), resps) + + return filtered_resps + + +@register_filter("remove_whitespace") +class WhitespaceFilter(Filter): + """Filters out leading whitespace from responses.""" + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + def filter_set(inst): + filtered_resp = [] + for resp in inst: + resp = resp.lstrip() + filtered_resp.append(resp) + return filtered_resp + + filtered_resps = [filter_set(resp) for resp in resps] + + return filtered_resps + + +@register_filter("multi_choice_regex") +class MultiChoiceRegexFilter(RegexFilter): + """ + A filter used to extract a model's answer on multiple choice questions with + letter answers. assumes each document has a "choices" field + containing the list of answer choices and that the answer label symbols + are of the form (A), (B), (C), ... or A, B, C. + """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, + ) -> None: + """ + regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure + - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. + - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. + group_select: Selects the (group_select)th match from the findall result. + ignore_case: Ignores the case during step 1 matching + ignore_punctuation: Remove the punctuation during step 1 matching + regexes_to_ignore: Remove these regexes during step 1 matching + """ + super().__init__(regex_pattern, group_select, fallback) + self.ignore_case = ignore_case + self.ignore_punctuation = ignore_punctuation + self.regexes_to_ignore = regexes_to_ignore + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + + def find_match(regex, resp, convert_dict={}): + match = regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m][0] + match = match.strip() + if match and match in convert_dict: + match = convert_dict[match] + return match + + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + + def filter_ignores(st): + if self.regexes_to_ignore is not None: + for s in self.regexes_to_ignore: + st = re.sub(s, "", st) + + if self.ignore_case: + st = st.lower() + + if self.ignore_punctuation: + # https://stackoverflow.com/a/266162 + st = st.translate(punct_tbl) + return st + + filtered_resps = [] + + for r, doc in zip(resps, docs): + fallback_regexes = [] + choice_to_alpha = {} + next_alpha = "A" + + without_paren_fallback_regexes = [] + without_paren_to_target = {} + + choices = doc["choices"] + for c in choices: + m = filter_ignores(c.strip()) + fallback_regexes.append(f"{re.escape(m)}") + choice_to_alpha[m] = f"({next_alpha})" + + without_paren_fallback_regexes.append(next_alpha) + without_paren_to_target[next_alpha] = f"({next_alpha})" + + next_alpha = chr(ord(next_alpha) + 1) + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + rf":[\s]*({without_paren_fallback_regex})" + ) + + filtered = [] + for resp in r: + match = find_match(self.regex, resp) + if not match: + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) + if not match: + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) + if not match: + match = self.fallback + filtered.append(match) + filtered_resps.append(filtered) + + return filtered_resps diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/selection.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/selection.py new file mode 100644 index 0000000000000000000000000000000000000000..47b9c9bc71f254c91ba92aa8578b8c9f8cb3341f --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/selection.py @@ -0,0 +1,61 @@ +from collections import Counter + +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function +# that takes an input and returns a scalar and then should select the max reward, +# or should implement different filters for different ways of handling a reward model's inference. + + +@register_filter("take_first") +class TakeFirstFilter(Filter): + def __init__(self) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + def apply(self, resps, docs): + """ + Assuming each entry of `resps` is a list of model responses, we discard all but the first response. + """ + return map(lambda r: r[0], resps) + + +@register_filter("take_first_k") +class TakeKFilter(Filter): + def __init__(self, **kwargs) -> None: + self.k = kwargs.pop("k") + + super().__init__(**kwargs) + + def apply(self, resps, docs): + # need resp to be subscriptable to check below + resps = list(resps) + # check we have at least k responses per doc, else we can't take the first k + assert len(resps[0]) >= self.k, ( + f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." + ) + return map(lambda r: r[: self.k], resps) + + +@register_filter("majority_vote") +class MajorityVoteFilter(Filter): + def __init__(self) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + def apply(self, resps, docs): + """ + Each entry of `resps` is a list of model responses. + We select the response that occurs most frequently in each entry of `resps`. + """ + + def select_majority(resp): + counts = Counter(resp) + vote = counts.most_common(1)[0][0] + return vote + + return map(lambda r: [select_majority(r)], resps) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/transformation.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..48d2a21d7d510991977ebcf6601c2e7437ecb4bb --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/filters/transformation.py @@ -0,0 +1,122 @@ +import re + +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("lowercase") +class LowercaseFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def filter_set(inst): + return [resp.lower() for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("uppercase") +class UppercaseFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def filter_set(inst): + return [resp.upper() for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("map") +class MapFilter(Filter): + def __init__(self, mapping_dict: dict = None, default_value=None) -> None: + """ + Initializes the MapFilter with a given mapping dictionary and default value. + + Args: + - mapping_dict (dict): A dictionary containing the key-value mappings. + Default is an empty dictionary. + - default_value (Any): The value to be returned when a key is not found in the mapping_dict. + Default is None. + + Example: + mapper = MapFilter({'A': 1, 'B': 2}, default_value=0) + """ + if mapping_dict is None: + mapping_dict = {} + assert isinstance(mapping_dict, dict), ( + "Provided mapping_dict is not a dictionary" + ) + self.mapping_dict = mapping_dict + self.default_value = default_value + + def apply(self, resps, docs): + def filter_set(inst): + return [self.mapping_dict.get(resp, self.default_value) for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("format_span") +class SPANFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def format_ner_text(text): + label_dict = { + "person": "PER", + "location": "LOC", + "organization": "ORG", + "counties": "LOC", + "places": "LOC", + "people": "PER", + "persons": "PER", + "company": "ORG", + "country": "LOC", + "continent": "LOC", + "time": "DATE", + "date": "DATE", + "per": "PER", + "loc": "LOC", + "org": "ORG", + } + text = text.lower() + for key, value in label_dict.items(): + text = text.replace(key, value) + + text = "$".join(i for i in text.split("$$")) + return text.rstrip("$$") + + def format_named_entities(text): + """ + Extract named entities from text and format them as 'label: value $$ label: value'. + Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values. + """ + # Regular expression to match label: entities pattern + pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)" + # Normalize newline characters + text = text.replace("\n", "$").strip() + matches = re.findall(pattern, text) + + formatted_entities = [] + + for label, values in matches: + # Split multiple entities separated by commas and strip whitespace + entities = [value.strip() for value in values.split(",")] + + # Exclude 'none' entities + for entity in entities: + if entity.lower() != "none": + formatted_entities.append(f"{label.lower()}: {entity}") + + # Join entities with the desired separator + return " $ ".join(formatted_entities) + + def filter_set(inst): + return [ + format_named_entities(format_ner_text(resp.lower())) for resp in inst + ] + + return [filter_set(resp) for resp in resps] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7a6834c6486fde35ef02d715e90be3fba223a --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/__init__.py @@ -0,0 +1,2 @@ +from .evaluation_tracker import EvaluationTracker +from .wandb_logger import WandbLogger diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/evaluation_tracker.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/evaluation_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7f88978e73a8fad88d83a9563e85090b8c7e5594 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/evaluation_tracker.py @@ -0,0 +1,530 @@ +import json +import logging +import os +import re +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path + +from datasets import load_dataset +from datasets.utils.metadata import MetadataConfigs +from huggingface_hub import ( + DatasetCard, + DatasetCardData, + HfApi, + hf_hub_url, +) +from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status + +from dllm_eval.utils import ( + get_file_datetime, + get_file_task_name, + get_results_filenames, + get_sample_results_filenames, + handle_non_serializable, + hash_string, + sanitize_list, + sanitize_model_name, + sanitize_task_name, +) + + +eval_logger = logging.getLogger(__name__) + + +@dataclass(init=False) +class GeneralConfigTracker: + """ + Tracker for the evaluation parameters. + + Attributes: + model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.) + model_name (str): Name of the model. + model_name_sanitized (str): Sanitized model name for directory creation. + start_time (float): Start time of the experiment. Logged at class init. + end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`] + total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times). + """ + + model_source: str = None + model_name: str = None + model_name_sanitized: str = None + system_instruction: str = None + system_instruction_sha: str = None + fewshot_as_multiturn: bool = None + chat_template: str = None + chat_template_sha: str = None + start_time: float = None + end_time: float = None + total_evaluation_time_seconds: str = None + + def __init__(self) -> None: + """Starts the evaluation timer.""" + self.start_time = time.perf_counter() + + @staticmethod + def _get_model_name(model_args: str) -> str: + """Extracts the model name from the model arguments.""" + + def extract_model_name(model_args: str, key: str) -> str: + """Extracts the model name from the model arguments using a key.""" + args_after_key = model_args.split(key)[1] + return args_after_key.split(",")[0] + + # order does matter, e.g. peft and delta are provided together with pretrained + prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="] + for prefix in prefixes: + if prefix in model_args: + return extract_model_name(model_args, prefix) + return "" + + def log_experiment_args( + self, + model_source: str, + model_args: str, + system_instruction: str, + chat_template: str, + fewshot_as_multiturn: bool, + ) -> None: + """Logs model parameters and job ID.""" + self.model_source = model_source + self.model_name = GeneralConfigTracker._get_model_name(model_args) + self.model_name_sanitized = sanitize_model_name(self.model_name) + self.system_instruction = system_instruction + self.system_instruction_sha = ( + hash_string(system_instruction) if system_instruction else None + ) + self.chat_template = chat_template + self.chat_template_sha = hash_string(chat_template) if chat_template else None + self.fewshot_as_multiturn = fewshot_as_multiturn + + def log_end_time(self) -> None: + """Logs the end time of the evaluation and calculates the total evaluation time.""" + self.end_time = time.perf_counter() + self.total_evaluation_time_seconds = str(self.end_time - self.start_time) + + +class EvaluationTracker: + """ + Keeps track and saves relevant information of the evaluation process. + Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested. + """ + + def __init__( + self, + output_path: str = None, + hub_results_org: str = "", + hub_repo_name: str = "", + details_repo_name: str = "", + results_repo_name: str = "", + push_results_to_hub: bool = False, + push_samples_to_hub: bool = False, + public_repo: bool = False, + token: str = "", + leaderboard_url: str = "", + point_of_contact: str = "", + gated: bool = False, + ) -> None: + """ + Creates all the necessary loggers for evaluation tracking. + + Args: + output_path (str): Path to save the results. If not provided, the results won't be saved. + hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token. + hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`. + details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`. + result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo. + push_results_to_hub (bool): Whether to push the results to the Hugging Face hub. + push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub. + public_repo (bool): Whether to push the results to a public or private repository. + token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`. + leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card. + point_of_contact (str): Contact information on the Hugging Face hub dataset card. + gated (bool): Whether to gate the repository. + """ + self.general_config_tracker = GeneralConfigTracker() + + self.output_path = output_path + self.push_results_to_hub = push_results_to_hub + self.push_samples_to_hub = push_samples_to_hub + self.public_repo = public_repo + self.leaderboard_url = leaderboard_url + self.point_of_contact = point_of_contact + self.api = HfApi(token=token) if token else None + self.gated_repo = gated + + if not self.api and (push_results_to_hub or push_samples_to_hub): + raise ValueError( + "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. " + "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable." + ) + + if ( + self.api + and hub_results_org == "" + and (push_results_to_hub or push_samples_to_hub) + ): + hub_results_org = self.api.whoami()["name"] + eval_logger.warning( + f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'." + ) + + if hub_repo_name == "": + details_repo_name = ( + details_repo_name if details_repo_name != "" else "lm-eval-results" + ) + results_repo_name = ( + results_repo_name if results_repo_name != "" else details_repo_name + ) + else: + details_repo_name = hub_repo_name + results_repo_name = hub_repo_name + eval_logger.warning( + "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead." + ) + + self.details_repo = f"{hub_results_org}/{details_repo_name}" + self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private" + self.results_repo = f"{hub_results_org}/{results_repo_name}" + self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private" + + def save_results_aggregated( + self, + results: dict, + samples: dict, + ) -> None: + """ + Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested. + + Args: + results (dict): The aggregated results to save. + samples (dict): The samples results to save. + """ + self.general_config_tracker.log_end_time() + + if self.output_path: + try: + eval_logger.info("Saving results aggregated") + + # calculate cumulative hash for each task - only if samples are provided + task_hashes = {} + if samples: + for task_name, task_samples in samples.items(): + sample_hashes = [ + s["doc_hash"] + s["prompt_hash"] + s["target_hash"] + for s in task_samples + ] + task_hashes[task_name] = hash_string("".join(sample_hashes)) + + # update initial results dict + results.update({"task_hashes": task_hashes}) + results.update(asdict(self.general_config_tracker)) + dumped = json.dumps( + results, + indent=2, + default=handle_non_serializable, + ensure_ascii=False, + ) + + path = Path(self.output_path if self.output_path else Path.cwd()) + self.date_id = datetime.now().isoformat().replace(":", "-") + if path.suffix == ".json": + path.parent.mkdir(parents=True, exist_ok=True) + file_results_aggregated = path.with_name( + f"{path.stem}_{self.date_id}.json" + ) + else: + path.mkdir(parents=True, exist_ok=True) + file_results_aggregated = path.joinpath( + f"results_{self.date_id}.json" + ) + + file_results_aggregated.open("w", encoding="utf-8").write(dumped) + + if self.api and self.push_results_to_hub: + repo_id = ( + self.results_repo + if self.public_repo + else self.results_repo_private + ) + self.api.create_repo( + repo_id=repo_id, + repo_type="dataset", + private=not self.public_repo, + exist_ok=True, + ) + self.api.upload_file( + repo_id=repo_id, + path_or_fileobj=str(file_results_aggregated), + path_in_repo=os.path.join( + self.general_config_tracker.model_name, + file_results_aggregated.name, + ), + repo_type="dataset", + commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}", + ) + eval_logger.info( + "Successfully pushed aggregated results to the Hugging Face Hub. " + f"You can find them at: {repo_id}" + ) + + except Exception as e: + eval_logger.warning("Could not save results aggregated") + eval_logger.info(repr(e)) + else: + eval_logger.info( + "Output path not provided, skipping saving results aggregated" + ) + + def save_results_samples( + self, + task_name: str, + samples: dict, + ) -> None: + """ + Saves the samples results to the output path and pushes them to the Hugging Face hub if requested. + + Args: + task_name (str): The task name to save the samples for. + samples (dict): The samples results to save. + """ + if self.output_path: + try: + eval_logger.info(f"Saving per-sample results for: {task_name}") + + path = Path(self.output_path if self.output_path else Path.cwd()) + if path.suffix == ".json": + path = path.parent + path.mkdir(parents=True, exist_ok=True) + + file_results_samples = path.joinpath( + f"samples_{task_name}_{self.date_id}.jsonl" + ) + + for sample in samples: + # we first need to sanitize arguments and resps + # otherwise we won't be able to load the dataset + # using the datasets library + arguments = {} + for i, arg in enumerate(sample["arguments"]): + arguments[f"gen_args_{i}"] = {} + for j, tmp in enumerate(arg): + arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp + + sample["resps"] = sanitize_list(sample["resps"]) + sample["filtered_resps"] = sanitize_list(sample["filtered_resps"]) + sample["arguments"] = arguments + sample["target"] = str(sample["target"]) + + sample_dump = ( + json.dumps( + sample, + default=handle_non_serializable, + ensure_ascii=False, + ) + + "\n" + ) + + with open(file_results_samples, "a", encoding="utf-8") as f: + f.write(sample_dump) + + if self.api and self.push_samples_to_hub: + repo_id = ( + self.details_repo + if self.public_repo + else self.details_repo_private + ) + self.api.create_repo( + repo_id=repo_id, + repo_type="dataset", + private=not self.public_repo, + exist_ok=True, + ) + try: + if self.gated_repo: + headers = build_hf_headers() + r = get_session().put( + url=f"https://huggingface.co/api/datasets/{repo_id}/settings", + headers=headers, + json={"gated": "auto"}, + ) + hf_raise_for_status(r) + except Exception as e: + eval_logger.warning("Could not gate the repository") + eval_logger.info(repr(e)) + self.api.upload_folder( + repo_id=repo_id, + folder_path=str(path), + path_in_repo=self.general_config_tracker.model_name_sanitized, + repo_type="dataset", + commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}", + ) + eval_logger.info( + f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. " + f"You can find them at: {repo_id}" + ) + + except Exception as e: + eval_logger.warning("Could not save sample results") + eval_logger.info(repr(e)) + else: + eval_logger.info("Output path not provided, skipping saving sample results") + + def recreate_metadata_card(self) -> None: + """ + Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub. + """ + + eval_logger.info("Recreating metadata card") + repo_id = self.details_repo if self.public_repo else self.details_repo_private + + files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") + results_files = get_results_filenames(files_in_repo) + sample_files = get_sample_results_filenames(files_in_repo) + + # Build a dictionary to store the latest evaluation datetime for: + # - Each tested model and its aggregated results + # - Each task and sample results, if existing + # i.e. { + # "org__model_name__gsm8k": "2021-09-01T12:00:00", + # "org__model_name__ifeval": "2021-09-01T12:00:00", + # "org__model_name__results": "2021-09-01T12:00:00" + # } + latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat()) + + for file_path in sample_files: + file_path = Path(file_path) + filename = file_path.name + model_name = file_path.parent + task_name = get_file_task_name(filename) + results_datetime = get_file_datetime(filename) + task_name_sanitized = sanitize_task_name(task_name) + # Results and sample results for the same model and task will have the same datetime + samples_key = f"{model_name}__{task_name_sanitized}" + results_key = f"{model_name}__results" + latest_datetime = max( + latest_task_results_datetime[samples_key], + results_datetime, + ) + latest_task_results_datetime[samples_key] = latest_datetime + latest_task_results_datetime[results_key] = max( + latest_task_results_datetime[results_key], + latest_datetime, + ) + + # Create metadata card + card_metadata = MetadataConfigs() + + # Add the latest aggregated results to the metadata card for easy access + for file_path in results_files: + file_path = Path(file_path) + results_filename = file_path.name + model_name = file_path.parent + eval_date = get_file_datetime(results_filename) + eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date) + results_filename = Path("**") / Path(results_filename).name + config_name = f"{model_name}__results" + sanitized_last_eval_date_results = re.sub( + r"[^\w\.]", "_", latest_task_results_datetime[config_name] + ) + + if eval_date_sanitized == sanitized_last_eval_date_results: + # Ensure that all results files are listed in the metadata card + current_results = card_metadata.get(config_name, {"data_files": []}) + current_results["data_files"].append( + {"split": eval_date_sanitized, "path": [str(results_filename)]} + ) + card_metadata[config_name] = current_results + # If the results file is the newest, update the "latest" field in the metadata card + card_metadata[config_name]["data_files"].append( + {"split": "latest", "path": [str(results_filename)]} + ) + + # Add the tasks details configs + for file_path in sample_files: + file_path = Path(file_path) + filename = file_path.name + model_name = file_path.parent + task_name = get_file_task_name(filename) + eval_date = get_file_datetime(filename) + task_name_sanitized = sanitize_task_name(task_name) + eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date) + results_filename = Path("**") / Path(filename).name + config_name = f"{model_name}__{task_name_sanitized}" + sanitized_last_eval_date_results = re.sub( + r"[^\w\.]", "_", latest_task_results_datetime[config_name] + ) + if eval_date_sanitized == sanitized_last_eval_date_results: + # Ensure that all sample results files are listed in the metadata card + current_details_for_task = card_metadata.get( + config_name, {"data_files": []} + ) + current_details_for_task["data_files"].append( + {"split": eval_date_sanitized, "path": [str(results_filename)]} + ) + card_metadata[config_name] = current_details_for_task + # If the samples results file is the newest, update the "latest" field in the metadata card + card_metadata[config_name]["data_files"].append( + {"split": "latest", "path": [str(results_filename)]} + ) + + # Get latest results and extract info to update metadata card examples + latest_datetime = max(latest_task_results_datetime.values()) + latest_model_name = max( + latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k] + ) + last_results_file = [ + f for f in results_files if latest_datetime.replace(":", "-") in f + ][0] + last_results_file_path = hf_hub_url( + repo_id=repo_id, filename=last_results_file, repo_type="dataset" + ) + latest_results_file = load_dataset( + "json", data_files=last_results_file_path, split="train" + ) + results_dict = latest_results_file["results"][0] + new_dictionary = {"all": results_dict} + new_dictionary.update(results_dict) + results_string = json.dumps(new_dictionary, indent=4) + + dataset_summary = ( + "Dataset automatically created during the evaluation run of model " + ) + if self.general_config_tracker.model_source == "hf": + dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n" + else: + dataset_summary += f"{self.general_config_tracker.model_name}\n" + dataset_summary += ( + f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n" + f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each " + 'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n' + 'An additional configuration "results" store all the aggregated results of the run.\n\n' + "To load the details from a run, you can for instance do the following:\n" + ) + if self.general_config_tracker.model_source == "hf": + dataset_summary += ( + "```python\nfrom datasets import load_dataset\n" + f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n' + ) + dataset_summary += ( + "## Latest results\n\n" + f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) " + "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. " + 'You find each in the results and the "latest" split for each eval):\n\n' + f"```python\n{results_string}\n```" + ) + card_data = DatasetCardData( + dataset_summary=dataset_summary, + repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}", + pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}", + leaderboard_url=self.leaderboard_url, + point_of_contact=self.point_of_contact, + ) + card_metadata.to_dataset_card_data(card_data) + card = DatasetCard.from_template( + card_data, + pretty_name=card_data.pretty_name, + ) + card.push_to_hub(repo_id, repo_type="dataset") diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba795edb72d7b665a2c0fe6d4f3e3a5ed91b6940 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/utils.py @@ -0,0 +1,149 @@ +import logging +import os +import re +import subprocess +from importlib.metadata import version +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +from torch.utils.collect_env import get_pretty_env_info +from transformers import __version__ as trans_version + + +logger = logging.getLogger(__name__) + + +def remove_none_pattern(input_string: str) -> Tuple[str, bool]: + """Remove the ',none' substring from the input_string if it exists at the end. + + Args: + input_string (str): The input string from which to remove the ',none' substring. + + Returns: + Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed + and a boolean indicating whether the modification was made (True) or not (False). + """ + # Define the pattern to match ',none' at the end of the string + pattern = re.compile(r",none$") + + # Use sub() to replace ',none' with an empty string + result = re.sub(pattern, "", input_string) + + # check if the input_string changed + removed = result != input_string + + return result, removed + + +def _handle_non_serializable(o: Any) -> Union[int, str, list]: + """Handle non-serializable objects by converting them to serializable types. + + Args: + o (Any): The object to be handled. + + Returns: + Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32, + it will be converted to int. If the object is of type set, it will be converted + to a list. Otherwise, it will be converted to str. + """ + if isinstance(o, np.int64) or isinstance(o, np.int32): + return int(o) + elif isinstance(o, set): + return list(o) + else: + return str(o) + + +def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]: + try: + git_folder = Path(repo_path, ".git") + if git_folder.is_file(): + git_folder = Path( + git_folder.parent, + git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1], + ) + if Path(git_folder, "HEAD").exists(): + head_name = ( + Path(git_folder, "HEAD") + .read_text(encoding="utf-8") + .split("\n")[0] + .split(" ")[-1] + ) + head_ref = Path(git_folder, head_name) + git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "") + else: + git_hash = None + except Exception as err: + logger.debug( + f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}" + ) + return None + return git_hash + + +def get_git_commit_hash(): + """ + Gets the git commit hash of your current repo (if it exists). + Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42 + """ + try: + git_hash = subprocess.check_output(["git", "describe", "--always"]).strip() + git_hash = git_hash.decode() + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError occurs when git not installed on system + git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists + return git_hash + + +def add_env_info(storage: Dict[str, Any]): + try: + pretty_env_info = get_pretty_env_info() + except Exception as err: + pretty_env_info = str(err) + try: + dllm_eval_version = version("dllm_eval") + except Exception as err: + dllm_eval_version = str(err) + transformers_version = trans_version + upper_dir_commit = get_commit_from_path( + Path(os.getcwd(), "..") + ) # git hash of upper repo if exists + added_info = { + "pretty_env_info": pretty_env_info, + "transformers_version": transformers_version, + "dllm_eval_version": dllm_eval_version, + "upper_git_hash": upper_dir_commit, # in case this repo is submodule + } + storage.update(added_info) + + +def add_tokenizer_info(storage: Dict[str, Any], lm): + if getattr(lm, "tokenizer", False): + try: + tokenizer_info = { + "tokenizer_pad_token": [ + lm.tokenizer.pad_token, + str(lm.tokenizer.pad_token_id), + ], + "tokenizer_eos_token": [ + lm.tokenizer.eos_token, + str(lm.tokenizer.eos_token_id), + ], + "tokenizer_bos_token": [ + lm.tokenizer.bos_token, + str(lm.tokenizer.bos_token_id), + ], + "eot_token_id": getattr(lm, "eot_token_id", None), + "max_length": getattr(lm, "max_length", None), + } + storage.update(tokenizer_info) + except Exception as err: + logger.debug( + f"Logging detailed tokenizer info failed with {err}, skipping..." + ) + # seems gguf and textsynth do not have tokenizer + else: + logger.debug( + "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results." + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/wandb_logger.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/wandb_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0859b3c8e90437f21b6f06143b14941a7a96d2 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/loggers/wandb_logger.py @@ -0,0 +1,358 @@ +import copy +import json +import logging +from typing import Any, Dict, List, Literal, Tuple + +import numpy as np +import pandas as pd +from packaging.version import Version + +from dllm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern + + +logger = logging.getLogger(__name__) + + +def get_wandb_printer() -> Literal["Printer"]: + """Returns a wandb printer instance for pretty stdout.""" + from wandb.sdk.lib.printer import new_printer + + printer = new_printer() + return printer + + +class WandbLogger: + def __init__(self, init_args=None, config_args=None) -> None: + """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update() + + Args: + init_args Optional[Dict]: Arguments for init configuration. + config_args Optional[Dict]: Arguments for config + + Parse and log the results returned from evaluator.simple_evaluate() with: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + wandb_logger.log_eval_samples(results["samples"]) + """ + try: + import wandb + + assert Version(wandb.__version__) >= Version("0.13.6") + if Version(wandb.__version__) < Version("0.13.6"): + wandb.require("report-editing:v0") + except Exception as e: + logger.warning( + "To use the wandb reporting functionality please install wandb>=0.13.6.\n" + "To install the latest version of wandb run `pip install wandb --upgrade`\n" + f"{e}" + ) + + self.wandb_args: Dict[str, Any] = init_args or {} + self.wandb_config_args: Dict[str, Any] = config_args or {} + + # pop the step key from the args to save for all logging calls + self.step = self.wandb_args.pop("step", None) + + # initialize a W&B run + if wandb.run is None: + self.run = wandb.init(**self.wandb_args) + if self.wandb_config_args: + self.run.config.update(self.wandb_config_args) + else: + self.run = wandb.run + + self.printer = get_wandb_printer() + + def post_init(self, results: Dict[str, Any]) -> None: + self.results: Dict[str, Any] = copy.deepcopy(results) + self.task_names: List[str] = list(results.get("results", {}).keys()) + self.group_names: List[str] = list(results.get("groups", {}).keys()) + + def _get_config(self) -> Dict[str, Any]: + """Get configuration parameters.""" + self.task_configs = self.results.get("configs", {}) + cli_configs = self.results.get("config", {}) + configs = { + "task_configs": self.task_configs, + "cli_configs": cli_configs, + } + + return configs + + def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]: + """Sanitize the results dictionary.""" + _results = copy.deepcopy(self.results.get("results", dict())) + + # Remove None from the metric string name + tmp_results = copy.deepcopy(_results) + for task_name in self.task_names: + task_result = tmp_results.get(task_name, dict()) + for metric_name, metric_value in task_result.items(): + _metric_name, removed = remove_none_pattern(metric_name) + if removed: + _results[task_name][_metric_name] = metric_value + _results[task_name].pop(metric_name) + + # remove string valued keys from the results dict + wandb_summary = {} + for task in self.task_names: + task_result = _results.get(task, dict()) + for metric_name, metric_value in task_result.items(): + if isinstance(metric_value, str): + wandb_summary[f"{task}/{metric_name}"] = metric_value + + for summary_metric, summary_value in wandb_summary.items(): + _task, _summary_metric = summary_metric.split("/") + _results[_task].pop(_summary_metric) + + tmp_results = copy.deepcopy(_results) + for task_name, task_results in tmp_results.items(): + for metric_name, metric_value in task_results.items(): + _results[f"{task_name}/{metric_name}"] = metric_value + _results[task_name].pop(metric_name) + for task in self.task_names: + _results.pop(task) + + return wandb_summary, _results + + def _log_results_as_table(self) -> None: + """Generate and log evaluation results as a table to W&B.""" + columns = [ + "Version", + "Filter", + "num_fewshot", + "Metric", + "Value", + "Stderr", + ] + + def make_table(columns: List[str], key: str = "results"): + import wandb + + table = wandb.Table(columns=columns) + results = copy.deepcopy(self.results) + + for k, dic in results.get(key).items(): + if k in self.group_names and not key == "groups": + continue + version = results.get("versions").get(k) + if version == "N/A": + version = None + n = results.get("n-shot").get(k) + + for (mf), v in dic.items(): + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + if m == "alias": + continue + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + if se != "N/A": + se = "%.4f" % se + table.add_data(*[k, version, f, n, m, str(v), str(se)]) + else: + table.add_data(*[k, version, f, n, m, str(v), ""]) + + return table + + # log the complete eval result to W&B Table + table = make_table(["Tasks"] + columns, "results") + self.run.log({"evaluation/eval_results": table}, step=self.step) + + if "groups" in self.results.keys(): + table = make_table(["Groups"] + columns, "groups") + self.run.log({"evaluation/group_eval_results": table}, step=self.step) + + def _log_results_as_artifact(self) -> None: + """Log results as JSON artifact to W&B.""" + import wandb + + dumped = json.dumps( + self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False + ) + artifact = wandb.Artifact("results", type="eval_results") + with artifact.new_file("results.json", mode="w", encoding="utf-8") as f: + f.write(dumped) + self.run.log_artifact(artifact) + + def log_eval_result(self) -> None: + """Log evaluation results to W&B.""" + # Log configs to wandb + configs = self._get_config() + self.run.config.update(configs, allow_val_change=self.step is not None) + + wandb_summary, self.wandb_results = self._sanitize_results_dict() + # update wandb.run.summary with items that were removed + self.run.summary.update(wandb_summary) + # Log the evaluation metrics to wandb + self.run.log(self.wandb_results, step=self.step) + # Log the evaluation metrics as W&B Table + self._log_results_as_table() + # Log the results dict as json to W&B Artifacts + self._log_results_as_artifact() + + def _generate_dataset( + self, data: List[Dict[str, Any]], config: Dict[str, Any] + ) -> pd.DataFrame: + """Generate a dataset from evaluation data. + + Args: + data (List[Dict[str, Any]]): The data to generate a dataset for. + config (Dict[str, Any]): The configuration of the task. + + Returns: + pd.DataFrame: A dataframe that is ready to be uploaded to W&B. + """ + ids = [x["doc_id"] for x in data] + labels = [x["target"] for x in data] + instance = [""] * len(ids) + resps = [""] * len(ids) + filtered_resps = [""] * len(ids) + model_outputs = {} + + metrics_list = config["metric_list"] + metrics = {} + for metric in metrics_list: + metric = metric.get("metric") + if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data] + if metric in ["byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_bytes"] = [x[metric][1] for x in data] + else: + metrics[f"{metric}_words"] = [x[metric][1] for x in data] + else: + metrics[metric] = [x[metric] for x in data] + + if config["output_type"] == "loglikelihood": + instance = [x["arguments"][0][0] for x in data] + labels = [x["arguments"][0][1] for x in data] + resps = [ + f"log probability of continuation is {x['resps'][0][0][0]} " + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["resps"][0][0][1] else "be" + ) + for x in data + ] + filtered_resps = [ + f"log probability of continuation is {x['filtered_resps'][0][0]} " + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["filtered_resps"][0][1] else "be" + ) + for x in data + ] + elif config["output_type"] == "multiple_choice": + instance = [x["arguments"][0][0] for x in data] + choices = [ + "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])]) + for x in data + ] + resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data] + filtered_resps = [ + np.argmax([n[0] for n in x["filtered_resps"]]) for x in data + ] + elif config["output_type"] == "loglikelihood_rolling": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + elif config["output_type"] == "generate_until": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + + model_outputs["raw_predictions"] = resps + model_outputs["filtered_predictions"] = filtered_resps + + df_data = { + "id": ids, + "data": instance, + } + if config["output_type"] == "multiple_choice": + df_data["choices"] = choices + + tmp_data = { + "input_len": [len(x) for x in instance], + "labels": labels, + "output_type": config["output_type"], + } + df_data.update(tmp_data) + df_data.update(model_outputs) + df_data.update(metrics) + + return pd.DataFrame(df_data) + + def _log_samples_as_artifact( + self, data: List[Dict[str, Any]], task_name: str + ) -> None: + import wandb + + # log the samples as an artifact + dumped = json.dumps( + data, + indent=2, + default=_handle_non_serializable, + ensure_ascii=False, + ) + artifact = wandb.Artifact(f"{task_name}", type="samples_by_task") + with artifact.new_file( + f"{task_name}_eval_samples.json", mode="w", encoding="utf-8" + ) as f: + f.write(dumped) + self.run.log_artifact(artifact) + # artifact.wait() + + def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None: + """Log evaluation samples to W&B. + + Args: + samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task. + """ + task_names: List[str] = [ + x for x in self.task_names if x not in self.group_names + ] + + ungrouped_tasks = [] + tasks_by_groups = {} + + for task_name in task_names: + group_names = self.task_configs[task_name].get("group", None) + if group_names: + if isinstance(group_names, str): + group_names = [group_names] + + for group_name in group_names: + if not tasks_by_groups.get(group_name): + tasks_by_groups[group_name] = [task_name] + else: + tasks_by_groups[group_name].append(task_name) + else: + ungrouped_tasks.append(task_name) + + for task_name in ungrouped_tasks: + eval_preds = samples[task_name] + + # log the samples as a W&B Table + df = self._generate_dataset(eval_preds, self.task_configs.get(task_name)) + self.run.log({f"{task_name}_eval_results": df}, step=self.step) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + for group, grouped_tasks in tasks_by_groups.items(): + grouped_df = pd.DataFrame() + for task_name in grouped_tasks: + eval_preds = samples[task_name] + df = self._generate_dataset( + eval_preds, self.task_configs.get(task_name) + ) + df["group"] = group + df["task"] = task_name + grouped_df = pd.concat([grouped_df, df], ignore_index=True) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + self.run.log({f"{group}_eval_results": grouped_df}, step=self.step) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/LLaDA2.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/LLaDA2.py new file mode 100644 index 0000000000000000000000000000000000000000..783400310b4342ecc7a671926fa8e7afe3b05620 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/LLaDA2.py @@ -0,0 +1,726 @@ +import logging +import os +from datetime import timedelta +from typing import Dict, List, Literal, Optional, Tuple, Union, TypeVar +import torch +import torch.nn.functional as F +import numpy as np +import transformers +import json +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, +) +from datasets import Dataset +from accelerate.utils import get_max_memory +from packaging import version +from tqdm import tqdm +import torch.distributed as dist +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) +from dllm_eval.api.instance import Instance +from dllm_eval.api.model import LM, TemplateLM +from dllm_eval.api.registry import register_model +from dllm_eval.models.utils import get_dtype, configure_pad_token + +try: + from .hts_sampler import HTSSampler +except ImportError: + HTSSampler = None + +eval_logger = logging.getLogger(__name__) +T = TypeVar("T", bound="LM") + + +def add_gumbel_noise(logits, temperature): + if temperature == 0.0: + return logits + logits = logits.to(torch.float32) + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = (-torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + mask_num = mask_index.sum(dim=1, keepdim=True) + base = mask_num // steps + remainder = mask_num % steps + num_transfer_tokens = base.expand(-1, steps).clone() + if remainder.sum() > 0: + indices = torch.arange(steps, device=mask_index.device) + mask = indices.unsqueeze(0) < remainder + num_transfer_tokens[mask] += 1 + return num_transfer_tokens.to(torch.int64) + + +@register_model("LLaDA2") +class LLaDA2(TemplateLM): + AUTO_MODEL_CLASS = transformers.AutoModel + _DEFAULT_MAX_LENGTH = 20480 + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "causal", + revision: Optional[str] = "main", + subfolder: Optional[str] = None, + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + batch_size: Optional[Union[int]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = True, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = False, + escape_until:Optional[bool] = False, + prefix_token_id: Optional[int] = None, + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + mc_num: int = 1024, + remasking: str = "low_confidence", + mask_id: int = 156895, + is_check_greedy : bool =True, + assistant_prefix: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + self.mc_num = mc_num + self.mask_id = mask_id + self.remasking = remasking + self.pretrained = pretrained + self.is_check_greedy = is_check_greedy + self.assistant_prefix = assistant_prefix + self.add_bos_token = add_bos_token + self.escape_until = escape_until + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + if not (parallelize or accelerator.num_processes > 1): + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + revision = str(revision) + revision = revision + ("/" + subfolder if subfolder is not None else "") + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + ) + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + **kwargs, + ) + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + self.add_bos_token = add_bos_token + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." + ) + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + if not (parallelize or autogptq or hasattr(self, "accelerator")): + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." + ) + if gpus > 1: + if hasattr(self, "accelerator") and self.accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." + ) + elif gpus > self.accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes. " + "If you would like to use data parallelism, please launch the script " + "with 'accelerate launch *script*'. " + f"Current run will proceed with {self.accelerator.num_processes} devices." + ) + if self.accelerator.is_local_main_process: + eval_logger.info( + f"Using {gpus} devices with data parallelism" + ) + + self._device = torch.device(f"{self.accelerator.device}") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self._rank = 0 + self._world_size = 1 + else: + self._rank = 0 + self._world_size = 1 + else: + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + self.is_first_inference = True + + if HTSSampler is not None: + self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device) + eval_logger.info("HTSSampler initialized successfully.") + + @property + def rank(self): + if hasattr(self, "_rank"): + return self._rank + if hasattr(self, "accelerator"): + return self.accelerator.local_process_index + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def world_size(self): + if hasattr(self, "_world_size"): + return self._world_size + if hasattr(self, "accelerator"): + return self.accelerator.num_processes + return int(os.environ.get("WORLD_SIZE", 1)) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + if parallelize is None and gpus is not None and gpus > 1: + parallelize = True + args = {} + if parallelize: + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(len(max_memory_all_gpus)) + } if max_memory_per_gpu is not None else {k: v for k, v in max_memory_all_gpus.items()} + if hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() if k % num_local_processes == self.accelerator.process_index % num_local_processes + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" + args["offload_folder"] = offload_folder + if max_cpu_memory is not None: + args["max_memory"]["cpu"] = max_cpu_memory + eval_logger.info( + f"Model parallel set to True. Max memory per GPU: {args['max_memory']}, Device map: {args['device_map']}" + ) + else: + args["device_map"] = {"": str(self.device)} + eval_logger.info( + f"Model parallel set to False. Device map: {args['device_map']}" + ) + return args + + @property + def config(self): + return self._config + + @property + def model(self): + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length > 1e10: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _get_backend( + self, + config: Union[transformers.PretrainedConfig, transformers.AutoConfig], + backend: Literal["default", "causal", "seq2seq"] = "default", + trust_remote_code: Optional[bool] = False, + ) -> None: + assert backend in ["default", "causal", "seq2seq"] + if backend != "default": + self.backend = backend + eval_logger.info( + f"Overrode HF model backend type, and using type '{self.backend}'" + ) + else: + if ( + getattr(config, "model_type") + in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "seq2seq" + elif ( + getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "causal" + else: + eval_logger.warning( + "HF model type is neither CausalLM nor Seq2SeqLM. Assuming CausalLM." + ) + self.backend = "causal" + + def _get_config( + self, + pretrained: str, + revision: str = "main", + trust_remote_code: bool = False, + gguf_file: Optional[str] = None, + ) -> None: + self._config = transformers.AutoConfig.from_pretrained( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + ) + + def _create_model( + self, + pretrained: str, + revision: Optional[str] = "main", + dtype: Optional[Union[str, torch.dtype]] = "auto", + trust_remote_code: Optional[bool] = False, + parallelize: Optional[bool] = False, + gpus: Optional[int] = None, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + if autogptq or gptqmodel: + raise NotImplementedError("Quantization options are not implemented for this custom class.") + model_dtype = get_dtype(dtype) + eval_logger.info(f"Loading model with dtype: {model_dtype}") + model_kwargs = kwargs if kwargs else {} + if not parallelize: + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + ) + ) + self._model = transformers.AutoModelForCausalLM.from_pretrained( + pretrained, + revision=revision, + torch_dtype=model_dtype, + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + if peft: + from peft import PeftModel + eval_logger.info(f"Loading PEFT model from {peft}") + self._model = PeftModel.from_pretrained(self._model, peft, torch_dtype=model_dtype) + if not parallelize: + self._model = self._model.to(self.device) + self._model = self._model.to(torch.bfloat16) + self._model.eval() + + def _create_tokenizer( + self, + pretrained: Union[str, transformers.PreTrainedModel], + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ], + revision: Optional[str] = "main", + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + gguf_file: Optional[str] = None, + add_bos_token: Optional[bool] = False, + ) -> None: + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + "use_fast": use_fast_tokenizer + } + if add_bos_token: + kwargs["add_bos_token"] = True + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, **kwargs) + else: + self.tokenizer = tokenizer + else: + model_name = pretrained if isinstance(pretrained, str) else self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs) + + def tok_encode( + self, string: str, left_truncate_len=None, add_special_tokens=None + ) -> List[int]: + special_tokens_kwargs = {} + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs["add_special_tokens"] = self.add_bos_token + else: + special_tokens_kwargs["add_special_tokens"] = add_special_tokens + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_batch_encode( + self, + strings: List[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + add_special_tokens = {"add_special_tokens": self.add_bos_token} if self.backend == "causal" else {} + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len and encoding["input_ids"].size(1) > left_truncate_len: + eval_logger.warning( + f"Left-truncating from {encoding['input_ids'].size(1)} to {left_truncate_len} tokens." + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self.tokenizer.padding_side = old_padding_side + return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device) + + def tok_decode(self, tokens, skip_special_tokens=False): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + with torch.no_grad(): + if self.backend == "seq2seq": + return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits + else: + return self.model(inps, attention_mask=attn_mask).logits + + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + raise NotImplementedError + + def loglikelihood(self, requests): + raise NotImplementedError + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + gen_kwargs = requests[0].args[1] + use_hts = gen_kwargs.get("use_hts", False) + + realtime_output = gen_kwargs.get("realtime_output", "realtime_hts_results.jsonl") + baseline_realtime_output = "realtime_baseline_results.jsonl" + + if not use_hts: + bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running Baseline") + ds_data = [{"text": req.args[0]} for req in requests] + ds = Dataset.from_list(ds_data) + + req_idx = 0 + for batch in ds.iter(batch_size=int(self.batch_size)): + contexts = batch["text"] + context_enc, _ = self.tok_batch_encode(contexts) + prompt_length = context_enc.shape[1] + + out_full = self.model.generate( + inputs=context_enc, + steps=gen_kwargs.get("steps", 32), + gen_length=gen_kwargs.get("gen_length", 512), + block_length=gen_kwargs.get("block_length", 32), + temperature=gen_kwargs.get("temperature", 0.7), + eos_early_stop=gen_kwargs.get("eos_early_stop", False), + ) + generated_tokens = out_full[:, prompt_length:] + cont_toks_list = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + + for i, s in enumerate(cont_toks_list): + s = s.strip() + + if not self.escape_until: + until_terms = gen_kwargs.get("until", []) + for term in until_terms: + if len(term) > 0 and term in s: + s = s.split(term)[0] + + orig_req = requests[req_idx] + target_val = getattr(orig_req, "target", None) + if target_val is None or target_val == "N/A": + if "test" in orig_req.doc and "entry_point" in orig_req.doc: + target_val = orig_req.doc["test"] + "\ncheck(" + orig_req.doc["entry_point"] + ")" + else: + target_val = orig_req.doc.get("answer", orig_req.doc.get("solution", "N/A")) + + with open(baseline_realtime_output, "a", encoding="utf-8") as f: + f.write(json.dumps({ + "doc": orig_req.doc, + "target": target_val, + "resps": [[s]], + "prompt": contexts[i] + }, ensure_ascii=False) + "\n") + f.flush() + + res.append(s) + bar.update(1) + req_idx += 1 + bar.close() + + else: + bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running HTS+SVF") + for req in requests: + prompt_text = req.args[0] + context_enc, _ = self.tok_batch_encode([prompt_text]) + + p_interval = int(gen_kwargs.get("pruning_interval", 0)) + + final_codes, stats = self.hts_sampler.generate_hts( + prompt_text=prompt_text, + input_ids=context_enc, + initial_N=int(gen_kwargs.get("hts_N", 4)), + final_K=int(gen_kwargs.get("final_K", 1)), + hts_survivor_k=int(gen_kwargs.get("hts_survivor_k", 4)), + hts_mode=gen_kwargs.get("hts_mode", True), + hts_start_pct=float(gen_kwargs.get("hts_start_pct", 0.1)), + hts_end_pct=float(gen_kwargs.get("hts_end_pct", 0.6)), + decay_factor=float(gen_kwargs.get("decay_factor", 1.5)), + pruning_interval=p_interval, + reward_mode=gen_kwargs.get("reward_mode", "svf"), + task_type=gen_kwargs.get("task_type", "code"), + steps=int(gen_kwargs.get("steps", 32)), + gen_length=int(gen_kwargs.get("gen_length", 512)), + block_length=int(gen_kwargs.get("block_length", 32)), + temperature=float(gen_kwargs.get("temperature", 0.7)), + top_p=float(gen_kwargs.get("top_p", 0.95)), + top_k=gen_kwargs.get("top_k", None), + threshold=float(gen_kwargs.get("threshold", 0.85)), + mask_id=self.mask_id, + eos_id=self.eot_token_id + ) + + processed_codes = [] + for code in final_codes: + code = code.strip() + if not self.escape_until: + until_terms = gen_kwargs.get("until", []) + for term in until_terms: + if len(term) > 0 and term in code: + code = code.split(term)[0] + processed_codes.append(code) + + final_choice = processed_codes[0] + res.append(final_choice) + + target_val = getattr(req, "target", None) + if target_val is None or target_val == "N/A": + if "test" in req.doc and "entry_point" in req.doc: + target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")" + else: + target_val = req.doc.get("answer", req.doc.get("solution", "N/A")) + + with open(realtime_output, "a", encoding="utf-8") as f: + all_resps = [[code] for code in processed_codes] + + output_data = { + "doc": req.doc, + "target": target_val, + "resps": all_resps, + "prompt": prompt_text, + "entropy_history": stats.get("entropy_history", []), + "pruning_history": stats.get("pruning_history", []), + "final_scores": stats.get("final_scores", []), + "all_trajectories": stats.get("all_trajectories", []), + "nfe": stats.get("nfe", 0), + "svf_calls": stats.get("svf_calls", 0), + "total_steps": stats.get("total_steps", 0) + } + f.write(json.dumps(output_data, ensure_ascii=False) + "\n") + f.flush() + + bar.update(1) + bar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + if self.assistant_prefix: + chat_templated += self.assistant_prefix + return chat_templated \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b229acb5fb21c2f423fcd43a8a235b5a0d12239 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/__init__.py @@ -0,0 +1,19 @@ +from . import ( + LLaDA2, + huggingface, +) +# from .configuration_llada import LLaDAConfig +# from .modeling_llada import LLaDAModelLM + + +try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + import huggingface_hub.constants # type: ignore + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True +except ImportError: + pass + + +# __all__ = ['LLaDAConfig', 'LLaDAModelLM'] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/dummy.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..4702a36cb29809c9dd08c516b99e74e71ffcc166 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/dummy.py @@ -0,0 +1,41 @@ +import random + +from tqdm import tqdm + +from dllm_eval.api.model import LM +from dllm_eval.api.registry import register_model + + +@register_model("dummy") +class DummyLM(LM): + def __init__(self) -> None: + super().__init__() + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + return cls() + + def loglikelihood(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append((-random.random(), False)) + + return res + + def generate_until(self, requests, disable_tqdm: bool = False): + res = [] + + for request in tqdm(requests, disable=disable_tqdm): + res.append("lol") + assert request.arguments[0].strip() != "" + + return res + + def loglikelihood_rolling(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append(-random.random()) + + return res diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/hts_sampler.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/hts_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e35ad551d7074f6b49ea4e2ca83ede270649c04b --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/hts_sampler.py @@ -0,0 +1,323 @@ +import torch +import torch.nn.functional as F +import numpy as np +from .verifier import CodeVerifier +import logging +import re +import math + +logger = logging.getLogger(__name__) + +class HTSSampler: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + self.verifier = CodeVerifier(model, tokenizer, device) + + def _get_num_transfer_tokens(self, block_length, steps): + if steps == 0: return torch.tensor([], dtype=torch.int64) + base = block_length // steps + remainder = block_length % steps + num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64) + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + def _sample_with_temperature(self, logits, temperature, top_k, top_p): + logits = logits.to(torch.float32) + + orig_probs = torch.softmax(logits, dim=-1) + x0_p, _ = torch.max(orig_probs, dim=-1) + + if temperature > 0.0: + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10) + logits = logits / temperature + gumbel_noise + + if top_k is not None and top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float('Inf') + + x0 = torch.argmax(logits, dim=-1) + + return x0, x0_p + + def _safe_scalar(self, val): + if isinstance(val, torch.Tensor): + if val.numel() > 1: return val.mean().item() + return val.item() + return float(val) + + def _analyze_structure(self, text, task_type="code"): + score = 0.0 + stripped = text.strip() + if task_type == "code": + if len(stripped) < 5: return -0.1 + keywords = ["return", "print", "yield", "lambda", "class ", "def "] + if any(k in stripped for k in keywords): score += 0.05 + if ":" in stripped: score += 0.02 + if " " in text: score += 0.03 + elif task_type == "math": + if "\\boxed{" in stripped: score += 0.1 + if "The answer is" in stripped: score += 0.05 + if len(stripped) < 10: return -0.1 + if "Step" in stripped and stripped.count("Step") > 15: score -= 0.2 + return score + + def _chunked_forward(self, x, chunk_size=32, slice_start=None): + total_batch = x.shape[0] + logits_list = [] + for i in range(0, total_batch, chunk_size): + end_idx = min(i + chunk_size, total_batch) + sub_x = x[i:end_idx] + with torch.no_grad(): + outputs = self.model(input_ids=sub_x) + sub_logits = outputs.logits + if slice_start is not None: + s_start = slice_start if slice_start >= 0 else sub_logits.shape[1] + slice_start + sub_logits = sub_logits[:, s_start:, :] + logits_list.append(sub_logits.detach().clone()) + return torch.cat(logits_list, dim=0) + + def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id, + prompt_length, resample_window=5, task_type="code"): + num_survivors = len(survivor_indices) + if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone() + + if task_type == "math": resample_window = 12 + elif task_type == "reasoning": resample_window = 10 + elif task_type == "code": resample_window = 6 + + base_repeat = target_width // num_survivors + remainder = target_width % num_survivors + new_x_list = [] + new_conf_list = [] + + for i in range(num_survivors): + count = base_repeat + (1 if i < remainder else 0) + if count == 0: continue + + survivor_x = x[survivor_indices[i]] + survivor_conf = conf_scores[survivor_indices[i]] + + new_x_list.append(survivor_x.unsqueeze(0)) + new_conf_list.append(survivor_conf.unsqueeze(0)) + + if count > 1: + gen_part = survivor_x[prompt_length:] + gen_conf = survivor_conf[prompt_length:] + non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0] + + for _ in range(count - 1): + perturbed_x = survivor_x.clone() + perturbed_conf = survivor_conf.clone() + + if len(non_mask_indices) > 0: + pool_size = min(resample_window * 2, len(non_mask_indices)) + current_token_confs = gen_conf[non_mask_indices] + + _, candidate_indices = torch.topk(current_token_confs, k=pool_size, largest=False) + + num_to_perturb = min(resample_window, pool_size) + rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb] + selected_sub_indices = candidate_indices[rand_indices] + + target_indices_in_x = prompt_length + non_mask_indices[selected_sub_indices] + perturbed_x[target_indices_in_x] = mask_id + perturbed_conf[target_indices_in_x] = 0.0 + + new_x_list.append(perturbed_x.unsqueeze(0)) + new_conf_list.append(perturbed_conf.unsqueeze(0)) + + return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0) + + @torch.no_grad() + def generate_hts(self, prompt_text, input_ids, problem_data=None, + initial_N=1, final_K=1, survivor_K=None, + prune_step_pct=0.0, reward_mode="confidence", + temperature=0.7, block_length=32, steps=64, gen_length=1024, + top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9, + eos_id=156892, mask_id=156895, + hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.5, + hts_survivor_k=4, task_type="code", until=None, pruning_interval=0): + + input_ids = input_ids.to(self.device) + if input_ids.shape[0] == 1: input_ids = input_ids.repeat(initial_N, 1) + + schedule_map = {} + ts_start, tr_end = 0, 0 + if not hts_mode: + final_K_list = [final_K] if not isinstance(final_K, list) else final_K + prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct + survivor_K_list = final_K_list if survivor_K is None else ([survivor_K] if not isinstance(survivor_K, list) else survivor_K) + if len(survivor_K_list) < len(final_K_list): survivor_K_list.extend(final_K_list[len(survivor_K_list):]) + for pct, width, parents in zip(prune_pct_list, final_K_list, survivor_K_list): + if pct > 0: + s = int(steps * pct) + schedule_map[s] = (width, parents) + else: + final_K_list = [final_K] if not isinstance(final_K, int) else [final_K] + ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct) + + prompt_length = input_ids.shape[1] + num_blocks = (prompt_length + gen_length + block_length - 1) // block_length + total_length = num_blocks * block_length + + x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device) + x[:, :prompt_length] = input_ids.clone() + + conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device) + conf_scores[:, :prompt_length] = 1.0 + + prefill_blocks = prompt_length // block_length + num_gen_blocks = max(1, num_blocks - prefill_blocks) + current_bsz = initial_N + + next_allowed_pruning_step = ts_start if hts_mode else 0 + + stats = { + "initial_n": initial_N, "final_k": final_K_list[-1], + "pruning_history": [], "entropy_history": [], "nfe": 0.0, + "svf_calls": 0, "final_scores": [], "total_steps": steps + } + + for num_block in range(prefill_blocks, num_blocks): + window_end = (num_block + 1) * block_length + schedule = self._get_num_transfer_tokens(block_length, steps) + + for step in range(steps): + cur_x = x[:current_bsz, :window_end] + + perform_pruning = False + num_parents_to_select = 0 + + if hts_mode and step >= next_allowed_pruning_step and step < tr_end: + target_width = max(final_K_list[-1], math.ceil(initial_N * (decay_factor ** -(step - ts_start)))) + if current_bsz > target_width: + perform_pruning = True + num_parents_to_select = hts_survivor_k + elif not hts_mode and step in schedule_map: + target_width, num_parents_to_select = schedule_map[step] + if current_bsz > target_width: perform_pruning = True + + if perform_pruning: + stats["nfe"] += current_bsz + stats["svf_calls"] += current_bsz + + gen_logits = self._chunked_forward(cur_x, chunk_size=16, slice_start=prompt_length) + rough_ids = torch.argmax(gen_logits, dim=-1) + rough_codes_snippet = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True) + candidates = [] + for i in range(current_bsz): + full_code = rough_codes_snippet[i] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, full_code, mode=reward_mode, problem_data=problem_data, current_logits=gen_logits[i] if reward_mode != "svf" else None, task_type=task_type)) + s += self._analyze_structure(full_code, task_type=task_type) + clean_content = full_code.strip().replace(" ", "").replace("\n", "") + candidates.append({'score': s, 'idx': i, 'key': hash(clean_content[:200] + clean_content[-200:])}) + + stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]}) + candidates.sort(key=lambda x: x['score'], reverse=True) + + selected_indices, seen_keys = [], set() + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['key'] not in seen_keys: + selected_indices.append(cand['idx']); seen_keys.add(cand['key']) + + if len(selected_indices) < num_parents_to_select: + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['idx'] not in selected_indices: selected_indices.append(cand['idx']) + + top_indices = torch.tensor(selected_indices, device=self.device) + x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type) + + current_bsz = target_width + cur_x = x[:current_bsz, :window_end] + next_allowed_pruning_step = step + 1 + pruning_interval + + active_mask = cur_x[:, -block_length:] == mask_id + + stats["nfe"] += current_bsz + + active_logits = self._chunked_forward(cur_x, chunk_size=32, slice_start=-block_length) + + with torch.no_grad(): + if len(stats["entropy_history"]) < 32: + probs_for_stats = torch.softmax(active_logits.float(), dim=-1) + entropy_per_branch = (-(probs_for_stats * torch.log(probs_for_stats + 1e-10)).sum(dim=-1).mean(dim=-1)).cpu().numpy().tolist() + stats["entropy_history"].append(entropy_per_branch) + + x0, x0_p = self._sample_with_temperature(active_logits, temperature, top_k, top_p) + + num_transfer = schedule[step].item() + confidence = torch.where(active_mask, x0_p, -torch.inf) + transfer_idx = torch.zeros_like(x0, dtype=torch.bool) + + for b in range(current_bsz): + k_transfer = min(num_transfer, active_mask[b].sum().item()) + active_indices = torch.where(active_mask[b])[0] + if (confidence[b] > threshold).sum().item() >= k_transfer: + conf_indices = torch.where((confidence[b] > threshold) & active_mask[b])[0]; transfer_idx[b, conf_indices] = True + elif len(active_indices) > 0: + _, topk_indices = torch.topk(confidence[b][active_indices], k=min(k_transfer, len(active_indices))); transfer_idx[b, active_indices[topk_indices]] = True + + if transfer_idx.any(): + cur_x[:, -block_length:][transfer_idx] = x0[transfer_idx] + conf_scores[:current_bsz, window_end-block_length:window_end][transfer_idx] = x0_p[transfer_idx] + + if task_type in ["math", "reasoning"]: + for b in range(current_bsz): + gen_span = cur_x[b, prompt_length:window_end] + text_snippet = self.tokenizer.decode(gen_span, skip_special_tokens=True) + should_stop = False + if task_type == "reasoning" and ("###" in text_snippet): + should_stop = True + if task_type == "math" and ("\\boxed{" in text_snippet and "}" in text_snippet.split("\\boxed{")[-1]): + should_stop = True + + if should_stop: + non_mask_indices = (gen_span != mask_id).nonzero(as_tuple=True)[0] + if len(non_mask_indices) > 0: + last_idx = non_mask_indices[-1].item() + if last_idx + 1 < len(gen_span): + gen_span[last_idx + 1:] = eos_id + cur_x[b, prompt_length:window_end] = gen_span + if window_end < total_length: + x[b, window_end:] = eos_id + conf_scores[b, window_end:] = 1.0 + + for b in range(current_bsz): + gen_window = cur_x[b, prompt_length:window_end] + eos_indices = (gen_window == eos_id).nonzero(as_tuple=True)[0] + if len(eos_indices) > 0: + first_eos_idx = eos_indices[0].item() + if first_eos_idx + 1 < len(gen_window): + gen_window[first_eos_idx + 1:] = eos_id + cur_x[b, prompt_length:window_end] = gen_window + + x = x[:current_bsz] + x[:, :window_end] = cur_x + + stats["nfe"] = int(round(stats["nfe"])) + + final_gen_tokens = x[:current_bsz, prompt_length:] + final_codes = self.tokenizer.batch_decode(final_gen_tokens, skip_special_tokens=True) + final_candidates = [] + + stats["svf_calls"] += len(final_codes) + + for i in range(len(final_codes)): + txt = final_codes[i] + if until: + for term in until: + if term in txt: txt = txt.split(term)[0] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type)) + s += self._analyze_structure(txt, task_type) + final_candidates.append({'resp': txt, 'score': s}) + + final_candidates.sort(key=lambda x: x['score'], reverse=True) + stats["final_scores"] = [c['score'] for c in final_candidates] + stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)] + + return [c['resp'] for c in final_candidates], stats \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/huggingface.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6e1e99e20aeed5b20f7cd2d7a8f9b76155330a --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/huggingface.py @@ -0,0 +1,1489 @@ +import copy +import logging +import os +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import jinja2 +import torch +import torch.nn.functional as F +import transformers +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, + find_executable_batch_size, +) +from accelerate.utils import get_max_memory +from huggingface_hub import HfApi +from packaging import version +from peft import PeftModel +from peft import __version__ as PEFT_VERSION +from tqdm import tqdm +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) + +from dllm_eval import utils +from dllm_eval.api.instance import Instance +from dllm_eval.api.model import TemplateLM +from dllm_eval.api.registry import register_model +from dllm_eval.models.utils import ( + Collator, + clear_torch_cache, + configure_pad_token, + get_dtype, + handle_stop_sequences, + pad_and_concat, + stop_sequences_criteria, +) + + +eval_logger = logging.getLogger(__name__) + + +@register_model("hf-auto", "hf", "huggingface") +class HFLM(TemplateLM): + """ + An abstracted Huggingface model class. Enables usage with both models of + `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. + + Supports data-parallel multi-GPU with HF Accelerate. + """ + + AUTO_MODEL_CLASS = None + _DEFAULT_MAX_LENGTH = 2048 + + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "default", + # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) + revision: Optional[str] = "main", + subfolder: str = "", + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + softmax_dtype: Optional[Union[str, torch.dtype]] = None, + batch_size: Optional[Union[int, str]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = False, + prefix_token_id: Optional[int] = None, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + # optionally: take in an already-initialized transformers.PreTrainedModel + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + + # using one process with no model parallelism + if not (parallelize or accelerator.num_processes > 1): + # use user-passed device + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: # Parallelism managed by accelerate + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + # TODO: include in warning that `load_in_8bit` etc. affect this too + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + + revision = str(revision) # cast to string if not already one + + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + subfolder=subfolder, + ) + + # determine which of 'causal' and 'seq2seq' backends to use for HF models + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + + # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + subfolder=subfolder, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + + # if we passed `pretrained` as a string, initialize our model now + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + quantization_config=getattr(self.config, "quantization_config", None), + subfolder=subfolder, + **kwargs, + ) + + # access self._model through self.model property outside this method + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + # select (or create) a pad token to use + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + + self.add_bos_token = add_bos_token + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." + ) + + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + self.softmax_dtype = ( + get_dtype(softmax_dtype) if softmax_dtype is not None else None + ) + + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + # TODO: can remove this whole snippet except in the mps case, perhaps? + if not (parallelize or autogptq or hasattr(self, "accelerator")): + # place model onto device requested manually, + # if not using HF Accelerate or device_map + # or any other option that preloads model onto device + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." + ) + # multigpu data-parallel support when launched with accelerate + if gpus > 1: + if accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." + ) + elif gpus > accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes. " + "If you would like to use data parallelism, please launch the script " + "with 'accelerate launch *script*'. " + f"Current run will proceed with {accelerator.num_processes} devices." + ) + if self.accelerator.is_local_main_process: + eval_logger.info( + f"Using {gpus} devices with data parallelism" + ) + + self._device = torch.device(f"{accelerator.device}") + self.accelerator = accelerator + + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + # if we aren't launching via accelerate, ditch + self._rank = 0 + self._world_size = 1 + else: + # if a PreTrainedModel was passed into HFLM, we forgo distributed setup. + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes + if ( + num_machines == 0 + and hasattr(self, "accelerator") + and self.accelerator is not None + ): + eval_logger.info( + "We are not in a distributed setting for accelerate. Setting model_parallel to False." + ) + parallelize = False + + if parallelize is None: + # If parallelism is unset by the user, we automatically assign model parallelism + # if enough extra GPUs are available + max_memory_all_gpus = get_max_memory() + # We just want gpu, not cpu, max memory + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + parallelize = bool(num_local_processes < len(max_memory_all_gpus)) + eval_logger.info( + f"Setting model parallel to {parallelize} since " + f"the number of local processes is {num_local_processes} " + f"and the number of GPUs is {len(max_memory_all_gpus)}" + ) + + args = {} + if parallelize: # Model parallelism will be used + max_memory = {} + if max_memory_per_gpu is not None: # Using the provided memory requirements + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(gpus) + } + else: # Estimating the possible memory requirements + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + if not hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() + } + else: + # use only 1 / num_processes of the GPUs if we are running under accelerate launch + max_memory_per_gpu_map = { + k: v + for k, v in max_memory_all_gpus.items() + if k % num_local_processes + == (self.accelerator.process_index % num_local_processes) + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" if device_map is None else device_map + eval_logger.info( + f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}" + ) + + if max_cpu_memory is not None: + max_memory["cpu"] = max_cpu_memory + + args["offload_folder"] = offload_folder + elif ( + device_map is None + ): # No model parallelism, we use the default provided device for our model + if hasattr(self, "accelerator"): + device_map = {"": f"{self.accelerator.device}"} + else: + device_map = {"": str(self.device)} + args["max_memory"] = None + args["device_map"] = device_map + eval_logger.info( + f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}" + ) + else: + args["max_memory"] = None + args["device_map"] = None + eval_logger.info("Model parallel was set to False.") + + return args + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: # if max length manually set, return it + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _get_backend( + self, + config: Union[transformers.PretrainedConfig, transformers.AutoConfig], + backend: Literal["default", "causal", "seq2seq"] = "default", + trust_remote_code: Optional[bool] = False, + ) -> None: + """ + Helper method during initialization. + Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used. + sets `self.AUTO_MODEL_CLASS` appropriately if not already set. + + **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM, + user must set `self.backend` to be either "causal" or "seq2seq" manually!** + """ + + assert backend in ["default", "causal", "seq2seq"] + + if backend != "default": + # if we've settled on non-default backend, use that manually + if backend == "causal": + self.backend = backend + elif backend == "seq2seq": + self.backend = backend + eval_logger.info( + f"Overrode HF model backend type, and using type '{self.backend}'" + ) + else: + # determine and use the default HF backend for this model, based on its config + metadata. + if ( + getattr(config, "model_type") + in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + ): + # first check if model type is listed under seq2seq models, since some + # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. + # these special cases should be treated as seq2seq models. + self.backend = "seq2seq" + eval_logger.debug(f"Using model type '{self.backend}'") + elif ( + getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "causal" + eval_logger.debug(f"Using model type '{self.backend}'") + else: + if not trust_remote_code: + eval_logger.warning( + "HF model type is neither marked as CausalLM or Seq2SeqLM. \ + This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." + "Setting backend to causal" + ) + # if model type is neither in HF transformers causal or seq2seq model registries + # then we default to assuming AutoModelForCausalLM + self.backend = "causal" + eval_logger.info( + f"Model type cannot be determined. Using default model type '{self.backend}'" + ) + + if self.AUTO_MODEL_CLASS is None: + if self.backend == "causal": + self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM + elif self.backend == "seq2seq": + self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM + + def _get_config( + self, + pretrained: str, + revision: str = "main", + trust_remote_code: bool = False, + gguf_file: Optional[str] = None, + subfolder: str = "", + ) -> None: + """Return the model config for HuggingFace models""" + self._config = transformers.AutoConfig.from_pretrained( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + subfolder=subfolder, + ) + + def _create_model( + self, + pretrained: str, + revision: Optional[str] = "main", + dtype: Optional[Union[str, torch.dtype]] = "auto", + trust_remote_code: Optional[bool] = False, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + # (accelerate naive PP (device_map) options) + parallelize: Optional[bool] = False, + gpus: Optional[int] = None, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + quantization_config: Optional[Dict[str, Any]] = None, + subfolder: str = "", + **kwargs, + ) -> None: + """ + Initializes an HF or HF-compatible PreTrainedModel from scratch + inside HFLM, using the kwargs passed into self.__init__(). + + Also handles functionality such as AutoGPTQ usage and PEFT wrapping. + + For future similar extensions to AutoGPTQ that are not core to HF's ecosystem, + (such as PyTorch models that are nearly, but not quite, fully mirroring + HF's public interface relied on in this HFLM class) + please consider subclassing HFLM and overriding this and other methods as needed. + """ + + model_kwargs = kwargs if kwargs else {} + + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + device_map=kwargs.get("device_map", None), + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + gpus=gpus, + ) + ) + + if not autogptq and not gptqmodel: + if model_kwargs.get("load_in_4bit", None): + assert transformers.__version__ >= "4.30.0", ( + "load_in_4bit requires transformers >= 4.30.0" + ) + if transformers.__version__ >= "4.30.0": + if model_kwargs.get("load_in_4bit", None): + if model_kwargs.get("bnb_4bit_compute_dtype", None): + model_kwargs["bnb_4bit_compute_dtype"] = get_dtype( + model_kwargs["bnb_4bit_compute_dtype"] + ) + + self._model = self.AUTO_MODEL_CLASS.from_pretrained( + pretrained, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + quantization_config=quantization_config, + subfolder=subfolder, + **model_kwargs, + ) + else: + if autogptq and gptqmodel: + raise ValueError( + "Cannot use both 'autogptq' and 'gptqmodel' options at the same time." + ) + + if autogptq: + try: + from auto_gptq import AutoGPTQForCausalLM + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load auto_gptq, but auto-gptq is not installed ", + "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", + ) + + self._model = AutoGPTQForCausalLM.from_quantized( + pretrained, + trust_remote_code=trust_remote_code, + model_basename=None if autogptq is True else Path(autogptq).stem, + use_safetensors=True + if autogptq is True + else autogptq.endswith(".safetensors"), + **model_kwargs, + ) + + if gptqmodel: + try: + from gptqmodel import GPTQModel + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load gptqmodel, but gptqmodel is not installed ", + "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`", + ) + + self._model = GPTQModel.from_quantized( + pretrained, trust_remote_code=trust_remote_code, **model_kwargs + ) + + if peft and delta: + raise ValueError( + "Cannot use both 'peft' and 'delta' options at the same time." + ) + + if peft: + if model_kwargs.get("load_in_4bit", None): + if version.parse(PEFT_VERSION) < version.parse("0.4.0"): + raise AssertionError("load_in_4bit requires peft >= 0.4.0") + if self._model.config.vocab_size != len(self.tokenizer): + # resize model for LoRAs with added tokens + eval_logger.info( + f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..." + ) + self._model.resize_token_embeddings(len(self.tokenizer)) + self._model = PeftModel.from_pretrained( + self._model, peft, revision=revision + ) + elif delta: + if autogptq: + eval_logger.warning( + "Delta weights might trigger unexpected behavior when used with AutoGPTQ." + ) + _model_delta = self.AUTO_MODEL_CLASS.from_pretrained( + delta, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + for name, param in self._model.state_dict().items(): + try: + param.data += _model_delta.state_dict()[name] + except KeyError: + raise KeyError(f"Delta model is missing weights for layer: {name}") + except Exception as e: + raise RuntimeError( + f"Failed to add delta weights to layer {name}. Error: {e}" + ) + + del _model_delta + + return None + + def _create_tokenizer( + self, + pretrained: Union[str, transformers.PreTrainedModel], + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ], + revision: Optional[str] = "main", + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + gguf_file: Optional[str] = None, + add_bos_token: Optional[bool] = False, + subfolder: Optional[str] = "", + ) -> None: + """ + Helper method during initialization. + + Create a tokenizer object corresponding to the correct + tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed. + """ + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + } + + # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param + if gguf_file is not None: + kwargs["gguf_file"] = gguf_file + else: + kwargs["use_fast"] = use_fast_tokenizer + + if add_bos_token: + kwargs["add_bos_token"] = True + + if subfolder: + kwargs["subfolder"] = subfolder + + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer, **kwargs + ) + else: + assert isinstance( + tokenizer, transformers.PreTrainedTokenizer + ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast) + self.tokenizer = tokenizer + else: + # Get tokenizer based on 'pretrained' + if isinstance(pretrained, str): + model_name = pretrained + else: + # get the HF hub name via accessor on model + model_name = self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, **kwargs + ) + return None + + def _detect_batch_size(self, requests=None, pos: int = 0): + if requests: + _, context_enc, continuation_enc = requests[pos] + max_length = len( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] + ) + max_context_enc = len(context_enc[-(self.max_length + 1) :]) + max_cont_enc = len(continuation_enc[-(self.max_length + 1) :]) + else: + max_length = self.max_length + max_context_enc = max_length + max_cont_enc = max_length + + # if OOM, then halves batch_size and tries again + @find_executable_batch_size(starting_batch_size=self.max_batch_size) + def forward_batch(batch_size): + if self.backend == "seq2seq": + length = max(max_context_enc, max_cont_enc) + batched_conts = torch.ones( + (batch_size, length), device=self.device + ).long() + test_batch = torch.ones((batch_size, length), device=self.device).long() + call_kwargs = { + "attn_mask": test_batch, + "labels": batched_conts, + } + else: + call_kwargs = {} + test_batch = torch.ones( + (batch_size, max_length), device=self.device + ).long() + for _ in range(5): + out = F.log_softmax( # noqa: F841 + self._model_call(test_batch, **call_kwargs), + dim=-1, + dtype=self.softmax_dtype, + ) + + return batch_size + + try: + batch_size = forward_batch() + except RuntimeError as e: + if "No executable batch size found" in str(e): + batch_size = 1 + else: + raise + + if self.world_size > 1: + # if multi-GPU, always take minimum over all selected batch sizes + max_rnk_bs = torch.tensor([batch_size], device=self.device) + gathered = ( + self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist() + ) + batch_size = min(gathered) + clear_torch_cache() + return batch_size + + clear_torch_cache() + return batch_size + + def tok_encode( + self, string: str, left_truncate_len=None, add_special_tokens=None + ) -> List[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs = { + "add_special_tokens": False or self.add_bos_token + } + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: List[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self.backend == "causal": + add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][ + :, -left_truncate_len: + ] + self.tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + """ + :param inps: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attn_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + with torch.no_grad(): + if attn_mask is not None or labels is not None: + assert attn_mask is not None and labels is not None + assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM + return self.model( + input_ids=inps, attention_mask=attn_mask, labels=labels + ).logits + else: + assert self.AUTO_MODEL_CLASS in ( + transformers.AutoModelForCausalLM, + transformers.AutoModelForVision2Seq, + ) + return self.model(inps).logits + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = stop_sequences_criteria( + self.tokenizer, stop, context.shape[1], context.shape[0] + ) + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=True, + **generation_kwargs, + ) + + def _select_cont_toks( + self, logits: torch.Tensor, contlen: int = None, inplen: int = None + ) -> torch.Tensor: + if self.backend == "causal": + assert contlen and inplen, ( + "Must pass input len and cont. len to select scored logits for causal LM" + ) + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self.backend == "seq2seq": + assert contlen and not inplen, ( + "Selecting scored logits for Seq2SeqLM requires only cont. len" + ) + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + rolling_token_windows: List[Tuple[List[int], List[int]]] = list( + map( + utils.make_disjoint_window, + utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self.device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self.batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial( + "loglikelihood_rolling", (string,), request_total + ) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self.batch_schedule) + if sched in self.batch_sizes: + return self.batch_sizes[sched] + if (len(self.batch_sizes) > 1) and ( + self.batch_sizes[sched - 1] == self.max_batch_size + ): + # if previous batch size is already maximal, skip recomputation + self.batch_sizes[sched] = self.max_batch_size + return self.batch_sizes[sched] + print( + f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size" + ) + self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self.batch_sizes[sched]}") + return self.batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: List[Tuple[Tuple[str, str], List[int], List[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> List[Tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key to group and lookup one-token continuations""" + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. + return req[-2] + req[-1][:-1] + + re_ord = Collator( + requests, + sort_fn=_collate, + group_by="contexts" + if self.backend == "causal" and self.logits_cache + else None, + group_fn=_lookup_one_token_cont, + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else override_bs + if override_bs is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" + and n_reordered_requests > 0 + and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self.backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + eval_logger.warning( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + elif self.backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self.device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = ( + max(padding_len_cont, contlen) + if padding_len_cont is not None + else contlen + ) + + padding_len_inp = ( + max(padding_len_inp, inplen) + if padding_len_inp is not None + else inplen + ) + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self.backend == "causal": + batched_inps = pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self.backend == "seq2seq": + # TODO: left-pad encoder inps and mask? + batched_inps = pad_and_concat( + padding_len_inp, inps + ) # [batch, padding_len_inp] + batched_conts = pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attn_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), + dim=-1, + dtype=self.softmax_dtype, + ) # [batch, padding_length (inp or cont), vocab] + + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = ( + inplen + (logits.shape[0] - padding_len_inp) + if self.backend == "causal" + else None + ) + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + cont_toks = torch.tensor( + cont_toks, dtype=torch.long, device=self.device + ).unsqueeze(0) # [1, seq] + # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]). + # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens + # by choosing key with longest cont if group_by="contexts". + max_equal = ( + greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks + ).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( + -1 + ) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial( + "loglikelihood", request_str, answer + ) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[str]: + res = [] + + def _collate(req: Tuple[str, dict]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(req[0]) + return -len(toks), req[0] + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else adaptive_batch_size + if adaptive_batch_size is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" and not adaptive_batch_size + else None + ) + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + re_ords = Collator( + [reg.args for reg in requests], + sort_fn=_collate, + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError( + f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" + ) + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + if self.backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert max_ctx_len > 0, ( + f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." + ) + elif self.backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + context_enc, attn_masks = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self.truncation, + ) + context_enc = context_enc.to(self.device) + attn_masks = attn_masks.to(self.device) + + if "max_length" not in kwargs: + kwargs["max_length"] = context_enc.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + context=context_enc, + attention_mask=attn_masks, + stop=until, + **kwargs, + ) + + cont_toks_list = cont.tolist() + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self.backend == "causal": + cont_toks = cont_toks[context_enc.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning( + "Failed to apply chat template. removing the system role in chat history." + ) + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated + + def get_model_info(self) -> dict: + """ + Method to get Hugging Face model information for experiment reproducibility. + """ + + def get_model_num_params(model) -> int: + if hasattr(model, "num_parameters"): + return model.num_parameters() + if hasattr(model, "parameters"): + return sum(p.numel() for p in model.parameters()) + else: + return -1 + + def get_model_dtype(model) -> str: + if hasattr(model, "dtype"): + return model.dtype + else: + return "" + + def get_model_sha(pretrained: str, revision: str) -> str: + try: + model_info = HfApi().model_info(repo_id=pretrained, revision=revision) + return model_info.sha + except Exception as e: + eval_logger.debug( + f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}" + ) + return "" + + model_info = { + "model_num_parameters": get_model_num_params(self._model), + "model_dtype": get_model_dtype(self._model), + "model_revision": self.revision, + "model_sha": get_model_sha(self.pretrained, self.revision), + } + if self.peft: + model_info["peft_sha"] = get_model_sha(self.peft, self.revision) + if self.delta: + model_info["delta_sha"] = get_model_sha(self.delta, self.revision) + return model_info diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e17fa224b22fbbef442c94e13d4f7c237d3c647d --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/utils.py @@ -0,0 +1,854 @@ +import collections +import fnmatch +import gc +import itertools +import logging +import time +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) + +import torch +import transformers + + +eval_logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from PIL import Image + from transformers import PreTrainedTokenizerBase + from transformers.configuration_utils import PretrainedConfig + + +def chunks(iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + for i, x in enumerate(iter): + arr.append(x) + if len(arr) == (fn(i, iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +class MultiChoice: + def __init__(self, choices) -> None: + self.choices = choices + + # Simple wildcard support (linux filename patterns) + def __contains__(self, values) -> bool: + for value in values.split(","): + if len(fnmatch.filter(self.choices, value)) == 0: + eval_logger.info("Available tasks to choose:") + for choice in self.choices: + eval_logger.info(f" - {choice}") + raise ValueError("'{}' is not in task list".format(value)) + return True + + def __iter__(self) -> Iterator: + for choice in self.choices: + yield choice + + +class Grouper: + """ + takes an array `arr` and function `fn` and returns a dictionary + with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all + objects in `arr` satisfying `key == fn(ob)`. + """ + + def __init__(self, arr, fn) -> None: + # self.orig_arr = arr + self.size = len(arr) + arr = list(enumerate(arr)) + + def group_return_dict(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + return res + + arr = group_return_dict(arr, lambda x: fn(x[1])) + + # self.arr has format Dict[Tuple[int, ]] + self.arr = arr + self._grouped = None + + def get_grouped(self): + # return the contents but not indices for our grouped dict. + if self._grouped: + return self._grouped + grouped = {} + for key in self.arr.keys(): + # drop the index from each element of self.arr + grouped[key] = [y[1] for y in self.arr[key]] + self._grouped = grouped + return grouped + + def get_original(self, grouped_dict): + # take in a grouped dictionary with e.g. results for each key listed + # in the same order as the instances in `self.arr`, and + # return the results in the same (single list) order as `self.orig_arr`. + res = [None] * self.size + cov = [False] * self.size + # orig = [None] * self.size + + assert grouped_dict.keys() == self.arr.keys() + + for key in grouped_dict.keys(): + for (ind, _), v in zip(self.arr[key], grouped_dict[key]): + res[ind] = v + cov[ind] = True + # orig[ind] = _ + + assert all(cov) + # assert orig == self.orig_arr + + return res + + +def pad_and_concat( + max_length: int, + tensors: List[torch.Tensor], + padding_side: Literal["right", "left"] = "right", +): + """ + Method for padding a list of tensors given the maximum tensor + length in the batch. Used for batching inputs and continuations in + seq2seq models. + """ + assert padding_side == "left" or padding_side == "right", ( + f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" + ) + + for i, tensor in enumerate(tensors): + if len(tensor.shape) == 2: + tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size + tensor_len = tensor.shape[0] + if tensor_len < max_length: + if padding_side == "right": + # right-pad + tensors[i] = torch.cat( + [ + tensor, # [seq] + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + ], + dim=0, + ).unsqueeze(0) + else: + # left-pad + tensors[i] = torch.cat( + [ + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + tensor, # [seq] + ], + dim=0, + ).unsqueeze(0) + else: + tensors[i] = tensor.unsqueeze(0) + + return torch.cat(tensors, dim=0) + + +def clear_torch_cache() -> None: + gc.collect() + torch.cuda.empty_cache() + + +def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig""" + if isinstance(dtype, str) and dtype != "auto": + # Convert `str` args torch dtype: `float16` -> `torch.float16` + _torch_dtype = getattr(torch, dtype) + else: + _torch_dtype = dtype + return _torch_dtype + + +class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence.""" + + def __init__( + self, + sequence: str, + tokenizer: transformers.PreTrainedTokenizer, + initial_decoder_input_length: int, + batch_size: int, + ) -> None: + self.initial_decoder_input_length = initial_decoder_input_length + self.done_tracker = [False] * batch_size + self.sequence = sequence + self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False) + # print(sequence, self.sequence_ids) + # we look back for 2 more tokens than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model, + # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized + # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described. + self.sequence_id_len = len(self.sequence_ids) + 2 + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :] + + lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :] + + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + + for i, done in enumerate(self.done_tracker): + if not done: + self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + +def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizer, + stop_sequences: List[str], + initial_decoder_input_length: int, + batch_size: int, +) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList( + [ + *[ + MultiTokenEOSCriteria( + sequence, tokenizer, initial_decoder_input_length, batch_size + ) + for sequence in stop_sequences + ], + ] + ) + + +def undistribute(iterable): + """ + Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute . + + Re-interleaves results that have been split using more_itertools.distribute: + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + >>> undistribute([group_1, group_2]) + [1, 2, 3, 4, 5, 6] + + Handles non-uniform component lengths: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + >>> undistribute(children) + [1, 2, 3, 4, 5, 6, 7] + + Also handles when some iterables are empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + >>> undistribute(children) + [1, 2, 3] + + """ + + return [ + x + for x in itertools.chain.from_iterable( + itertools.zip_longest(*[list(x) for x in iterable]) + ) + if x is not None + ] + + +def retry_on_specific_exceptions( + on_exceptions: List[Type[Exception]], + max_retries: Optional[int] = None, + backoff_time: float = 3.0, + backoff_multiplier: float = 1.5, + on_exception_callback: Optional[Callable[[Exception, float], Any]] = None, +): + """Retry on an LLM Provider's rate limit error with exponential backoff + For example, to use for OpenAI, do the following: + ``` + from openai import RateLimitError + + # Recommend specifying max_retries to avoid infinite loops! + @retry_on_specific_exceptions([RateLimitError], max_retries=3) + def completion(...): + # Wrap OpenAI completion function here + ... + ``` + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + sleep_time = backoff_time + attempt = 0 + while max_retries is None or attempt < max_retries: + try: + return func(*args, **kwargs) + except tuple(on_exceptions) as e: + if on_exception_callback is not None: + on_exception_callback(e, sleep_time) + time.sleep(sleep_time) + sleep_time *= backoff_multiplier + attempt += 1 + + return wrapper + + return decorator + + +class Collator: + """ + A class for reordering and batching elements of an array. + + This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data. + + Objects of this class have the group_by attribute which determines the method for grouping + the data while batching it. Three options include "gen_kwargs", "contexts", or None: + If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs + If group_by == "contexts" then requests will be grouped by context + cont[:-1] + If None then requests will just be reordered by length descending. + """ + + def __init__( + self, + arr: List, + sort_fn: Callable = lambda x: x, + group_fn: Callable = lambda x: x[1], + group_by: Union[Literal["gen_kwargs", "contexts"], None] = None, + ) -> None: + self._group_by = group_by + # 0 indices are enumerated indices. Apply functions to original arr. + self._sort_fn = lambda x: sort_fn(x[1]) + self._group_fn = lambda x: group_fn(x[1]) + self._reorder_indices: List = [] + self._size = len(arr) + self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple( + enumerate(arr) + ) # [indices, (arr)] + if self._group_by == "contexts": + self._group_by_context() + elif self._group_by == "gen_kwargs": + self._group_by_index() + + def _group_by_index(self) -> None: + """Group the elements of a list based on their indices.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs" + ) + + def _group_by_context(self) -> None: + """Group the array with indices by context.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="contexts" + ) + + def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator: + """ + Generates and yields batches from the reordered array. The method of grouping and batching + depends on the parameter `group_by`. + If `group_by` is set to "gen_kwargs", it will batch the + re-ordered values with same gen_kwargs for each batch. + If `group_by` is "contexts", it caches the requests by context before batching. + If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array + + Parameters: + - n (int): The size of each batch. Defaults to 1. + - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of + each batch. Optional, defaults to None. + + Returns: + Iterator: An iterator over batches of reordered elements grouped as per the `group_by` + attribute. + + Yields: + List of batched elements according to the `group_by` attribute. + """ + if self._group_by == "gen_kwargs": + for ( + key, + values, + ) in self._arr_with_indices.items(): # type: ignore + values = self._reorder(values) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + elif self._group_by == "contexts": + # Get one sample from each key. + # Select longest continuation per group to ensure sufficient context logits + values = self._reorder( + [ + max(value, key=lambda x: len(x[1][-1])) + for value in self._arr_with_indices.values() + ] + ) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + else: + values = self._reorder(self._arr_with_indices) # type: ignore + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + + def get_cache( + self, + req_str: Tuple[str, str] = None, + cxt_toks: List[int] = None, + cont_toks: List[int] = None, + logits: torch.Tensor = None, + ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]: + """ + Retrieves cached single-token continuations and their associated arguments, updating indices as necessary. + + The behavior of this function varies depending on how the `group_by` attribute is set: + + - When `group_by` is "contexts": + The function identifies single-token continuations by checking for keys that equate to + [context+continuation][-1] and logs the indices for re-ordering. + In this mode, this function can work in two scenarios: + + 1. Cache Hit - Single Match: + If a single matching context-continuation pair is found in the cache, + the function yields the original arguments. + + 2. Cache Hit - Multiple Matches: + If multiple matching context-continuation pairs are found in the cache, + the function expands the logits batch dimension to match the number of cache hits. + It updates the original requests and continuation tokens. + + - When `group_by` is not set to "contexts": + This method yields the original arguments, logits and continuation tokens, + without checking for one-token continuations. + + Parameters: + - req_str (tuple[str, str]): Original strings used for CachingLM. + - cxt_toks (list[int]): Full context tokens used for lookup. + - cont_toks (list[int]): Continuation tokens for which logits were generated. + - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys. + + Yields: + - Iterator: + - req_str (tuple[str, str]): strings used for CachingLM. + - cont_toks (list[int]) : continuation tokens. + - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times) + """ + if self._group_by == "contexts": + cache_hit: List[ + Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]] + ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1])) + if (cache_size := len(cache_hit)) == 1: + self._reorder_indices.extend(x[0] for x in cache_hit) + yield req_str, cont_toks, logits + else: + # If we have matching requests then expand the batch dimension (no-op) and + # yield each along with its corresponding args. + multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size) + indices, req_str, cont_toks = zip( + *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit] + ) + self._reorder_indices.extend(indices) + for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits): + yield c_key, cont_tok, logit + else: + yield req_str, cont_toks, logits + + def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator: + """ + Reorders the elements in the array based on the sorting function. + + Parameters: + - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered. + + Yields: + Iterator + """ + arr = sorted(arr, key=self._sort_fn) + if not self._group_by == "contexts": + # If grouped by contexts then indices will be set in get_cache() + self._reorder_indices.extend([x[0] for x in arr]) + yield from [x[1] for x in arr] + + def get_original(self, newarr: List) -> List: + """ + Restores the original order of elements from the reordered list. + + Parameters: + - newarr (list): The reordered array. + + Returns: + list: The array with elements restored to their original order. + """ + res = [None] * self._size + cov = [False] * self._size + + for ind, v in zip(self._reorder_indices, newarr): + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + def __len__(self): + return self._size + + @staticmethod + def group( + arr: Iterable, + fn: Callable, + group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs", + ) -> dict: + """ + Groups elements of an iterable based on a provided function. + + + The `group_by` parameter determines the method of grouping. + If `group_by` is "contexts", the elements are grouped by [context + cont][:-1]. + If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict. + + Parameters: + - arr (Iterable): The iterable to be grouped. + - fn (Callable): The function to determine the grouping. + - values (bool): If True, returns the values of the group. Defaults to False. + + Returns: + Iterator: An iterable of grouped elements. + """ + res = collections.defaultdict(list) + for ob in arr: + # where ob == [context + cont] + if group_by == "contexts": + res[tuple(fn(ob))].append(ob) + else: + try: + hashable_dict = tuple( + ( + key, + tuple(value) + if isinstance(value, collections.abc.Iterable) + else value, + ) + for key, value in sorted(fn(ob).items()) + ) + res[hashable_dict].append(ob) + except (TypeError, AttributeError): + res[tuple(fn(ob))].append(ob) + return res + + @staticmethod + def get_chunks(_iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + _iter = tuple(_iter) + for i, x in enumerate(_iter): + arr.append(x) + if len(arr) == (fn(i, _iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +def configure_pad_token( + tokenizer: "PreTrainedTokenizerBase", + model_config: Optional["PretrainedConfig"] = None, +) -> "PreTrainedTokenizerBase": + """ + This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present. + Some tokenizers require special handling. + + Args: + tokenizer: The tokenizer for which the padding token is to be handled. + model_config: The configuration of the model. Default is None. + + Returns: + The tokenizer after the padding token has been handled. + + Raises: + AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0. + """ + if tokenizer.pad_token: + pass + elif tokenizer.unk_token: + tokenizer.pad_token_id = tokenizer.unk_token_id + elif tokenizer.eos_token: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + # handle special cases + if model_config and getattr(model_config, "model_type", None) == "qwen": + # Qwen's trust_remote_code tokenizer does not allow for adding special tokens + tokenizer.pad_token = "<|endoftext|>" + elif ( + tokenizer.__class__.__name__ == "RWKVWorldTokenizer" + or tokenizer.__class__.__name__ == "Rwkv5Tokenizer" + ): + # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0) + # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer + # --- + # Note that the world tokenizer class name, might change in the future for the final huggingface merge + # https://github.com/huggingface/transformers/pull/26963 + assert tokenizer.pad_token_id == 0 + else: + tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) + + return tokenizer + + +def replace_placeholders( + string: str, default_placeholder: str, image_token: str, max_images: int +): + """ + A utility function used for local multimodal models. It locates all `placeholder` string + occurrences in the given input `string_` and replaces the first `max_count` instances with + `replacement`, and all subsequent occurrences with the empty string. + + This is used to replace placeholder tags by model-specific image tokens like <|image_pad|> + and to allow for only the first `max_count` images to be passed to a model if desired. + + :param string: The original string containing placeholders. + :param default_placeholder: The placeholder text to be replaced. + :param image_token: The token to replace the placeholder with. + :param max_images: The maximum number of replacements to make. + :return: The string with placeholders replaced. + """ + count = 0 + result = [] + + parts = string.split(default_placeholder) + for part in parts[:-1]: # Iterate through all but the last part + result.append(part) + if count < max_images: + result.append(image_token) + count += 1 + elif default_placeholder != image_token: + result.append(default_placeholder) + + # Add the last part of the string + result.append(parts[-1]) + return "".join(result) + + +def flatten_image_list(images: List[List]): + """ + Takes in a list of lists of images, and returns a single list of all images in order. + Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor. + + :param images: A list of lists of PIL images. + :return: a list of PIL images, via concatenating all the sub-lists in order. + """ + return [image for image_list in images for image in image_list] + + +def handle_stop_sequences( + until: Union[str, List[str], None], eos: Optional[str] +) -> List[str]: + """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token.""" + if isinstance(until, str): + until = [until] + elif until is None: + until = [] + elif not isinstance(until, list): + raise ValueError( + f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" + ) + + if eos is not None and eos not in until: + until.append(eos) + return until + + +def resize_image( + image: "Image.Image", + width: Optional[int] = None, + height: Optional[int] = None, + max_dimension: Optional[int] = None, + keep_aspect_ratio: bool = True, + resample_filter: Union[int, str] = "Image.BICUBIC", + min_width: int = 1, + min_height: int = 1, +) -> "Image.Image": + """ + Resizes a PIL Image object with flexible options. + + Args: + image: The PIL Image object to resize. + width: Target width in pixels. + height: Target height in pixels. + max_dimension: Maximum size for the longer dimension of the image. + keep_aspect_ratio: If True (default) and both width and height are provided, + the image is resized to fit within these dimensions while + maintaining its aspect ratio. If False, the image is stretched + to the exact width and height. + resample_filter: The resampling filter to use for resizing. + Defaults to Image.BICUBIC. + min_width: Minimum width for the resized image. Defaults to 1. + min_height: Minimum height for the resized image. Defaults to 1. + + Returns: + The resized PIL Image object. If no resize parameters are provided + or if the image already meets the criteria, the original image is returned. + + Order of precedence for resizing: + 1. If width AND height are provided: + - If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio. + - If keep_aspect_ratio is False: Resizes to exact dimensions (may distort). + 2. Else if only width is provided: Calculates height proportionally. + 3. Else if only height is provided: Calculates width proportionally. + 4. Else if max_dimension is provided: Resizes the longest side to max_dimension + and scales the other side proportionally. + 5. If none of the above are provided, returns the original image. + """ + original_width, original_height = image.size + + # If no arguments are provided, return the original image + if width is None and height is None and max_dimension is None: + return image + + new_width = original_width + new_height = original_height + + if width is not None and height is not None: + # No resize needed if image is already smaller than target dimensions + if original_width <= width and original_height <= height: + return image + + if keep_aspect_ratio: + # Calculate the ratio to fit within the target dimensions + ratio = min(width / original_width, height / original_height) + new_width = int(original_width * ratio) + new_height = int(original_height * ratio) + else: + # Stretch to exact dimensions + new_width = width + new_height = height + elif width is not None: + # No resize needed if width is already smaller + if original_width <= width: + return image + # Calculate height proportionally + new_width = width + new_height = int((original_height / original_width) * new_width) + elif height is not None: + # No resize needed if height is already smaller + if original_height <= height: + return image + # Calculate width proportionally + new_height = height + new_width = int((original_width / original_height) * new_height) + elif max_dimension is not None: + # No resize needed if both dimensions are smaller than max_dimension + if max(original_height, original_width) <= max_dimension: + return image + + if original_width > original_height: + # Width is the longer side + new_width = max_dimension + new_height = int((original_height / original_width) * new_width) + else: + # Height is the longer side or sides are equal + new_height = max_dimension + new_width = int((original_width / original_height) * new_height) + + # Ensure dimensions are at least minimum values + new_width = max(min_width, new_width) + new_height = max(min_height, new_height) + + # Perform the resize operation with the calculated dimensions + return image.resize((new_width, new_height), resample_filter) + + +def truncate_tokens( + tokens: List[int], + max_length: int, + tokenizer: "PreTrainedTokenizerBase", + strategy: str = "left", +): + if strategy == "left": + return tokens[-max_length:] + elif strategy == "right": + return tokens[:max_length] + elif strategy == "middle": + # Truncate the middle of the sequence + left_length = max_length // 2 + right_length = max_length - left_length + return tokens[:left_length] + tokens[-right_length:] + return None diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/verifier.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..28f7ec6cc78d40d8df7e8d894d8d8d83222bffa7 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/models/verifier.py @@ -0,0 +1,147 @@ +import torch +import logging +import ast +import re +import numpy as np +import textwrap + +logger = logging.getLogger(__name__) + +class CodeVerifier: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + + self.yes_ids, self.no_ids = [], [] + for t in ["Yes", " Yes"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) == 1: self.yes_ids.append(ids[0]) + for t in ["No", " No"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) == 1: self.no_ids.append(ids[0]) + + def _extract_python_code(self, text): + text = text.strip() + match = re.search(r"```python\s*(.*?)```", text, re.DOTALL) + if match: return match.group(1) + match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL) + if match_generic: return match_generic.group(1) + return text + + def check_syntax(self, code_str): + clean_code = self._extract_python_code(code_str) + try: + if len(clean_code.strip()) < 5: return False + ast.parse(clean_code) + return True + except: + return False + + def compute_confidence(self, logits): + if logits is None: return 0.0 + probs = torch.softmax(logits, dim=-1) + max_probs, _ = torch.max(probs, dim=-1) + log_probs = torch.log(max_probs + 1e-10) + return torch.exp(torch.mean(log_probs)).item() + + def svf_score(self, prompt, code_str, task_type="code"): + + max_len = 2000 + if len(code_str) > max_len: + if task_type == "reasoning": + truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):] + else: + truncated_code = code_str[-max_len:] + else: + truncated_code = code_str + + if task_type == "code": + prompt_template = f""" + You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints. + + [Problem Statement] + {prompt} + [/Problem Statement] + + [Proposed Python Solution] + ```python + {truncated_code} + ``` + [/Proposed Python Solution] + + **Analysis Steps:** + 1. Correctness: Does the core algorithm correctly solve the problem? + 2. Efficiency: Is the time complexity acceptable for the given constraints? + 3. Edge Cases & Constraints: Does the code handle all rules and edge cases? + + **Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "math": + prompt_template = f""" + You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy. + + [Math Problem] + {prompt} + [/Math Problem] + + [Proposed Mathematical Solution] + {truncated_code} + [/Proposed Mathematical Solution] + + **Analysis Steps:** + 1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly? + 2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate? + 3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem? + + **Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "reasoning": + prompt_template = f""" + You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question. + + [Context and Question] + {prompt} + [/Context and Question] + + [Proposed Answer] + {truncated_code} + [/Proposed Answer] + + **Analysis Steps :** + 1. Faithfulness: Is the answer an exact, literal span from the context? + 2. Relevance: Does the answer directly address the specific question asked without hallucinating external information? + 3. Accuracy: Does the provided context strictly support this answer? + + **Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No. + **Answer:** """ + + else: + prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:" + + verify_text = textwrap.dedent(prompt_template).strip() + input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device) + + if input_ids.shape[1] > self.model.config.max_position_embeddings - 16: + logger.warning("Verifier input is too long, truncating from the left.") + input_ids = input_ids[:, - (self.model.config.max_position_embeddings - 16):] + + with torch.no_grad(): + outputs = self.model(input_ids) + logits = outputs.logits[0, -1, :] + + yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf')) + no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf')) + + if yes_score == -float('inf') and no_score == -float('inf'): return 0.5 + + probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0) + return probs[0].item() + + def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"): + if mode == "svf": + return self.svf_score(prompt, code_str, task_type=task_type) + else: + return self.compute_confidence(current_logits) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/prompts/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c2ce897dcde522ac82d0cbe0e06db1e02b1b72 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/prompts/__init__.py @@ -0,0 +1,128 @@ +import ast +import logging +import os +from typing import Dict + +from dllm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +# Prompt library. +# Stores prompts in a dictionary indexed by 2 levels: +# prompt category name, and prompt name. +# This allows us to access prompts +PROMPT_REGISTRY: Dict[str, Dict[str, str]] = { + "qa-basic": { + "question-newline-answer": "Question: {{question}}\nAnswer:", + "q-newline-a": "Q: {{question}}\nA:", + }, +} + + +def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None): + # unpack prompt name + category_name, prompt_name = prompt_id.split(":") + if subset_name is None: + dataset_full_name = dataset_name + else: + dataset_full_name = f"{dataset_name}-{subset_name}" + eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}") + if category_name == "promptsource": + try: + from promptsource.templates import DatasetTemplates + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load a Promptsource template, but promptsource is not installed ", + "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]", + ) + try: + if subset_name is None: + prompts = DatasetTemplates(dataset_name=dataset_name) + else: + prompts = DatasetTemplates( + dataset_name=dataset_name, subset_name=subset_name + ) + except Exception: + raise ValueError(f"{dataset_name} and {subset_name} not found") + if prompt_name in prompts.all_template_names: + return prompts[prompt_name] + else: + raise ValueError( + f"{prompt_name} not in prompt list {prompts.all_template_names}" + ) + elif ".yaml" in category_name: + import yaml + + with open(category_name, "rb") as file: + prompt_yaml_file = yaml.full_load(file) + + prompt_string = prompt_yaml_file["prompts"][prompt_name] + return PromptString(prompt_string) + else: + try: + return PROMPT_REGISTRY[category_name][prompt_name] + except Exception: + raise ValueError( + f"expected only a single `:` as separator between \ + prompt category and name, but got `{prompt_id}` instead" + ) + + +def load_prompt_list( + use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs +): + category_name, prompt_name = use_prompt.split(":") + + if category_name == "promptsource": + from promptsource.templates import DatasetTemplates + + if subset_name is None: + prompts = DatasetTemplates(dataset_name=dataset_name) + else: + prompts = DatasetTemplates( + dataset_name=dataset_name, subset_name=subset_name + ) + + prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) + + elif ".yaml" in category_name: + import yaml + + if yaml_path is not None: + category_name = os.path.realpath(os.path.join(yaml_path, category_name)) + + with open(category_name, "rb") as file: + prompt_yaml_file = yaml.full_load(file) + + prompt_list = utils.pattern_match( + prompt_name, prompt_yaml_file["prompts"].keys() + ) + + # category_name, *prompt_name = use_prompt.split(":") + # TODO allow to multiple prompt naming + # if len(prompt_name) > 1: + # prompt_list = [] + # for prompt in prompt_name: + # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names)) + # else: + # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) + return [":".join([category_name, prompt]) for prompt in prompt_list] + + +class PromptString: + def __init__(self, prompt_string): + self.prompt_string = prompt_string + + def apply(self, doc): + doc_to_text = self.prompt_string["doc_to_text"] + doc_to_target = self.prompt_string["doc_to_target"] + + # TODO need a way to process doc_to_choice + if "doc_to_choice" in self.prompt_string: + raise NotImplementedError("Not yet implemented to accept doc_to_choice") + + text_string = utils.apply_template(doc_to_text, doc) + target_string = utils.apply_template(doc_to_target, doc) + + return [text_string, target_string] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73d896452e06c2cc2909c290de70dcf87b0c6f90 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/__init__.py @@ -0,0 +1,670 @@ +import collections +import inspect +import logging +import os +from functools import partial +from typing import Dict, List, Mapping, Optional, Union + +from dllm_eval import utils +from dllm_eval.api.group import ConfigurableGroup, GroupConfig +from dllm_eval.api.task import ConfigurableTask, Task +from dllm_eval.evaluator_utils import get_subtask_list + + +GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys()) + +eval_logger = logging.getLogger(__name__) + + +class TaskManager: + """TaskManager indexes all tasks from the default `dllm_eval/tasks/` + and an optional directory if provided. + + """ + + def __init__( + self, + verbosity: Optional[str] = None, + include_path: Optional[Union[str, List]] = None, + include_defaults: bool = True, + metadata: Optional[dict] = None, + ) -> None: + if verbosity is not None: + utils.setup_logging(verbosity) + self.include_path = include_path + self.metadata = metadata + self._task_index = self.initialize_tasks( + include_path=include_path, include_defaults=include_defaults + ) + self._all_tasks = sorted(list(self._task_index.keys())) + + self._all_groups = sorted( + [x for x in self._all_tasks if self._task_index[x]["type"] == "group"] + ) + self._all_subtasks = sorted( + [ + x + for x in self._all_tasks + if self._task_index[x]["type"] in ["task", "python_task"] + ] + ) + self._all_tags = sorted( + [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"] + ) + + self.task_group_map = collections.defaultdict(list) + + def initialize_tasks( + self, + include_path: Optional[Union[str, List]] = None, + include_defaults: bool = True, + ) -> dict[str, dict]: + """Creates a dictionary of tasks indexes. + + :param include_path: Union[str, List] = None + An additional path to be searched for tasks recursively. + Can provide more than one such path as a list. + :param include_defaults: bool = True + If set to false, default tasks (those in dllm_eval/tasks/) are not indexed. + return + Dictionary of task names as key and task metadata + """ + if include_defaults: + all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"] + else: + all_paths = [] + if include_path is not None: + if isinstance(include_path, str): + include_path = [include_path] + all_paths.extend(include_path) + + task_index = {} + for task_dir in all_paths: + tasks = self._get_task_and_group(task_dir) + task_index = {**tasks, **task_index} + + return task_index + + @property + def all_tasks(self): + return self._all_tasks + + @property + def all_groups(self): + return self._all_groups + + @property + def all_subtasks(self): + return self._all_subtasks + + @property + def all_tags(self): + return self._all_tags + + @property + def task_index(self): + return self._task_index + + def list_all_tasks( + self, list_groups=True, list_tags=True, list_subtasks=True + ) -> str: + from pytablewriter import MarkdownTableWriter + + def sanitize_path(path): + # don't print full path if we are within the dllm_eval/tasks dir ! + # if we aren't though, provide the full path. + if "dllm_eval/tasks/" in path: + return "dllm_eval/tasks/" + path.split("dllm_eval/tasks/")[-1] + else: + return path + + group_table = MarkdownTableWriter() + group_table.headers = ["Group", "Config Location"] + gt_values = [] + for g in self.all_groups: + path = self.task_index[g]["yaml_path"] + if path == -1: + path = "---" + else: + path = sanitize_path(path) + gt_values.append([g, path]) + group_table.value_matrix = gt_values + + tag_table = MarkdownTableWriter() + tag_table.headers = ["Tag"] + tag_table.value_matrix = [[t] for t in self.all_tags] + + subtask_table = MarkdownTableWriter() + subtask_table.headers = ["Task", "Config Location", "Output Type"] + st_values = [] + for t in self.all_subtasks: + path = self.task_index[t]["yaml_path"] + + output_type = "" + + # read the yaml file to determine the output type + if path != -1: + config = utils.load_yaml_config(path, mode="simple") + if "output_type" in config: + output_type = config["output_type"] + elif ( + "include" in config + ): # if no output type, check if there is an include with an output type + include_path = path.split("/")[:-1] + config["include"] + include_config = utils.load_yaml_config(include_path, mode="simple") + if "output_type" in include_config: + output_type = include_config["output_type"] + + if path == -1: + path = "---" + else: + path = sanitize_path(path) + st_values.append([t, path, output_type]) + subtask_table.value_matrix = st_values + + result = "\n" + if list_groups: + result += group_table.dumps() + "\n\n" + if list_tags: + result += tag_table.dumps() + "\n\n" + if list_subtasks: + result += subtask_table.dumps() + "\n\n" + return result + + def match_tasks(self, task_list: list[str]) -> list[str]: + return utils.pattern_match(task_list, self.all_tasks) + + def _name_is_registered(self, name: str) -> bool: + if name in self.all_tasks: + return True + return False + + def _name_is_task(self, name: str) -> bool: + if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): + return True + return False + + def _name_is_tag(self, name: str) -> bool: + if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"): + return True + return False + + def _name_is_group(self, name: str) -> bool: + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "group" + ): + return True + return False + + def _name_is_python_task(self, name: str) -> bool: + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "python_task" + ): + return True + return False + + def _config_is_task(self, config: dict) -> bool: + if ("task" in config) and isinstance(config["task"], str): + return True + return False + + def _config_is_group(self, config: dict) -> bool: + if ("task" in config) and isinstance(config["task"], list): + return True + return False + + def _config_is_python_task(self, config: dict) -> bool: + if "class" in config: + return True + return False + + def _get_yaml_path(self, name: str): + if name not in self.task_index: + raise ValueError + return self.task_index[name]["yaml_path"] + + def _get_config(self, name): + if name not in self.task_index: + raise ValueError + yaml_path = self._get_yaml_path(name) + if yaml_path == -1: + return {} + else: + return utils.load_yaml_config(yaml_path, mode="full") + + def _get_tasklist(self, name): + if self._name_is_task(name): + raise ValueError + return self.task_index[name]["task"] + + def _process_alias(self, config, group=None): + # If the group is not the same as the original + # group which the group alias was intended for, + # Set the group_alias to None instead. + if ("group_alias" in config) and ("group" in config) and group is not None: + if config["group"] != group: + config["group_alias"] = None + return config + + def _class_has_config_in_constructor(self, cls): + constructor = getattr(cls, "__init__", None) + return ( + "config" in inspect.signature(constructor).parameters + if constructor + else False + ) + + def _load_individual_task_or_group( + self, + name_or_config: Optional[Union[str, dict]] = None, + parent_name: Optional[str] = None, + update_config: Optional[dict] = None, + ) -> Mapping: + def _load_task(config, task): + if "include" in config: + config = { + **utils.load_yaml_config( + yaml_path=None, + yaml_config={"include": config.pop("include")}, + mode="full", + ), + **config, + } + if self._config_is_python_task(config): + if self._class_has_config_in_constructor(config["class"]): + task_object = config["class"](config=config) + else: + task_object = config["class"]() + if isinstance(task_object, ConfigurableTask): + # very scuffed: set task name here. TODO: fixme? + task_object.config.task = task + else: + if self.metadata is not None: + config["metadata"] = config.get("metadata", {}) | self.metadata + else: + config["metadata"] = config.get("metadata", {}) + task_object = ConfigurableTask(config=config) + + return {task: task_object} + + def _get_group_and_subtask_from_config( + config: dict, + ) -> tuple[ConfigurableGroup, list[str]]: + if self.metadata is not None: + config["metadata"] = config.get("metadata", {}) | self.metadata + group_name = ConfigurableGroup(config=config) + subtask_list = [] + for task in group_name.config["task"]: + if isinstance(task, str) and self._name_is_tag(task): + subtask_list.extend(self._get_tasklist(task)) + else: + subtask_list.append(task) + return group_name, subtask_list + + def _process_group_config( + config: dict, update_config: dict = None + ) -> tuple[dict, dict]: + if update_config is not None: + config = {**config, **update_config} + _update_config = { + k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS + } + if not bool(_update_config): + _update_config = None + + group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS} + return group_config, _update_config + + if isinstance(name_or_config, str): + if update_config is not None: + # Process name_or_config as a dict instead + name_or_config = {"task": name_or_config, **update_config} + elif self._name_is_task(name_or_config) or self._name_is_python_task( + name_or_config + ): + task_config = self._get_config(name_or_config) + return _load_task(task_config, task=name_or_config) + else: + subtask_list = self._get_tasklist(name_or_config) + if subtask_list == -1: + group_config = self._get_config(name_or_config) + group_config, update_config = _process_group_config(group_config) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + else: + if self._name_is_tag(name_or_config): + fn = partial( + self._load_individual_task_or_group, + update_config=name_or_config + if isinstance(name_or_config, dict) + else None, + ) + return dict( + collections.ChainMap(*map(fn, reversed(subtask_list))) + ) + else: + group_name = ConfigurableGroup( + config={"group": name_or_config, "task": subtask_list} + ) + + if isinstance(name_or_config, dict): + if self._config_is_task(name_or_config): + name = name_or_config.pop("task") + if update_config is not None: + name_or_config = {**name_or_config, **update_config} + # If the name is registered as a group + if self._name_is_group(name): + group_config = self._get_config(name) + + group_config, update_config = _process_group_config( + group_config, name_or_config + ) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + elif self._name_is_tag(name): + subtask_list = self._get_tasklist(name) + fn = partial( + self._load_individual_task_or_group, + update_config=name_or_config, + ) + return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) + else: + if self._name_is_registered(name): + base_task_config = self._get_config(name) + + # Check if this is a duplicate. + if parent_name is not None: + num_duplicate = len( + list( + filter( + lambda x: x.startswith(name), + self.task_group_map[parent_name], + ) + ) + ) + if num_duplicate > 0: + name = f"{name}-{num_duplicate}" + self.task_group_map[parent_name].append(name) + + task_config = { + **base_task_config, + **name_or_config, + } + else: + task_config = name_or_config + return _load_task(task_config, task=name) + else: + group_config, update_config = _process_group_config(name_or_config) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + + fn = partial( + self._load_individual_task_or_group, + parent_name=group_name, + update_config=update_config, + ) + return { + group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) + } + + def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: + """Loads a dictionary of task objects from a list + + :param task_list: Union[str, list] = None + Single string or list of string of task names to be loaded + + :return + Dictionary of task objects + """ + if isinstance(task_list, str): + task_list = [task_list] + + all_loaded_tasks = dict( + collections.ChainMap( + *map( + lambda task: self._load_individual_task_or_group(task), + task_list, + ) + ) + ) + return all_loaded_tasks + + def load_config(self, config: Dict): + return self._load_individual_task_or_group(config) + + def _get_task_and_group(self, task_dir: str): + """Creates a dictionary of tasks index with the following metadata, + - `type`, that can be either `task`, `python_task`, `group` or `tags`. + `task` refer to regular task configs, `python_task` are special + yaml files that only consists of `task` and `class` parameters. + `group` are group configs. `tags` are labels that can be assigned + to tasks to assist in sorting and calling tasks of certain themes. + - `yaml_path`, path to the yaml file. If the entry is a `group` that + was configured through a task config, the yaml_path will be -1 + and all subtasks will be listed in `task` (see below) + - `task`, reserved for entries with `type` as `group`. This will list + all subtasks. When a group config is created (as opposed to task + config having `group` parameter set), this will be set to -1 to + avoid recursive indexing. The whole list of subtasks will be loaded + at evaluation. + + :param task_dir: str + A directory to check for tasks + + :return + Dictionary of task names as key and task metadata + """ + + def _populate_tags_and_groups(config, task, tasks_and_groups, print_info): + # TODO: remove group in next release + if "tag" in config: + attr_list = config["tag"] + if isinstance(attr_list, str): + attr_list = [attr_list] + + for tag in attr_list: + if tag not in tasks_and_groups: + tasks_and_groups[tag] = { + "type": "tag", + "task": [task], + "yaml_path": -1, + } + elif tasks_and_groups[tag]["type"] != "tag": + eval_logger.info( + f"The tag '{tag}' is already registered as a group, this tag will not be registered. " + "This may affect tasks you want to call." + ) + break + else: + tasks_and_groups[tag]["task"].append(task) + + # TODO: remove group in next release + print_info = True + ignore_dirs = [ + "__pycache__", + ".ipynb_checkpoints", + ] + tasks_and_groups = collections.defaultdict() + for root, dirs, file_list in os.walk(task_dir): + dirs[:] = [d for d in dirs if d not in ignore_dirs] + for f in file_list: + if f.endswith(".yaml"): + yaml_path = os.path.join(root, f) + print(yaml_path) + config = utils.load_yaml_config(yaml_path, mode="simple") + if self._config_is_python_task(config): + # This is a python class config + task = config["task"] + tasks_and_groups[task] = { + "type": "python_task", + "yaml_path": yaml_path, + } + _populate_tags_and_groups( + config, task, tasks_and_groups, print_info + ) + elif self._config_is_group(config): + # This is a group config + tasks_and_groups[config["group"]] = { + "type": "group", + "task": -1, # This signals that + # we don't need to know + # the task list for indexing + # as it can be loaded + # when called. + "yaml_path": yaml_path, + } + + # # Registered the level 1 tasks from a group config + # for config in config["task"]: + # if isinstance(config, dict) and self._config_is_task(config): + # task = config["task"] + # tasks_and_groups[task] = { + # "type": "task", + # "yaml_path": yaml_path, + # } + + elif self._config_is_task(config): + # This is a task config + task = config["task"] + tasks_and_groups[task] = { + "type": "task", + "yaml_path": yaml_path, + } + _populate_tags_and_groups( + config, task, tasks_and_groups, print_info + ) + else: + eval_logger.debug(f"File {f} in {root} could not be loaded") + + return tasks_and_groups + + +def get_task_name_from_config(task_config: Dict[str, str]) -> str: + if "task" in task_config: + return task_config["task"] + if "dataset_name" in task_config: + return "{dataset_path}_{dataset_name}".format(**task_config) + else: + return "{dataset_path}".format(**task_config) + + +def get_task_name_from_object(task_object): + if hasattr(task_object, "config"): + return task_object._config["task"] + + # TODO: scrap this + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) + + +def _check_duplicates(task_dict: dict) -> None: + """helper function solely used in validating get_task_dict output. + Takes the output of dllm_eval.evaluator_utils.get_subtask_list and + returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are + "oversubscribed" to several disjoint groups. + """ + subtask_names = [] + for key, value in task_dict.items(): + subtask_names.extend(value) + + duplicate_tasks = { + task_name for task_name in subtask_names if subtask_names.count(task_name) > 1 + } + + # locate the potentially problematic groups that seem to 'compete' for constituent subtasks + competing_groups = [ + group + for group in task_dict.keys() + if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0 + ] + + if len(duplicate_tasks) > 0: + raise ValueError( + f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs." + ) + + +def get_task_dict( + task_name_list: Union[str, List[Union[str, Dict, Task]]], + task_manager: Optional[TaskManager] = None, +): + """Creates a dictionary of task objects from either a name of task, config, or prepared Task object. + + :param task_name_list: List[Union[str, Dict, Task]] + Name of model or LM object, see dllm_eval.models.get_model + :param task_manager: TaskManager = None + A TaskManager object that stores indexed tasks. If not set, + task_manager will load one. This should be set by the user + if there are additional paths that want to be included + via `include_path` + + :return + Dictionary of task objects + """ + + task_name_from_string_dict = {} + task_name_from_config_dict = {} + task_name_from_object_dict = {} + + if isinstance(task_name_list, str): + task_name_list = [task_name_list] + elif isinstance(task_name_list, list): + if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]): + raise TypeError( + "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match." + ) + else: + raise TypeError( + f"Expected a 'str' or 'list' but received {type(task_name_list)}." + ) + + string_task_name_list = [task for task in task_name_list if isinstance(task, str)] + others_task_name_list = [ + task for task in task_name_list if not isinstance(task, str) + ] + if len(string_task_name_list) > 0: + if task_manager is None: + task_manager = TaskManager() + + task_name_from_string_dict = task_manager.load_task_or_group( + string_task_name_list + ) + + for task_element in others_task_name_list: + if isinstance(task_element, dict): + task_name_from_config_dict = { + **task_name_from_config_dict, + **task_manager.load_config(config=task_element), + } + + elif isinstance(task_element, Task): + task_name_from_object_dict = { + **task_name_from_object_dict, + get_task_name_from_object(task_element): task_element, + } + + if not set(task_name_from_string_dict.keys()).isdisjoint( + set(task_name_from_object_dict.keys()) + ): + raise ValueError + + final_task_dict = { + **task_name_from_string_dict, + **task_name_from_config_dict, + **task_name_from_object_dict, + } + + # behavior can get odd if one tries to invoke several groups that "compete" for the same task. + # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask + # and we'd be unsure which to use and report.) + # we explicitly check and error in this case. + _check_duplicates(get_subtask_list(final_task_dict)) + + return final_task_dict diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/README.md b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25467f905f61ef28883579f54672eab0e7c7dec6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/README.md @@ -0,0 +1,55 @@ +# AIME + +### Citation + +```text +@dataset{aime_1983_2024, + author = {Hemish Veeraboina}, + title = {AIME Problem Set 1983-2024}, + year = {2024}, + publisher = {Kaggle}, + url = {https://www.kaggle.com/datasets/hemishveeraboina/aime-problem-set-1983-2024} +} + +@dataset{aime_2024, + author = {Maxwell Jia}, + title = {AIME Problem Set 2024}, + year = {2024}, + publisher = {Huggingface}, + url = {https://huggingface.co/datasets/Maxwell-Jia/AIME_2024} +} + +@dataset{aime_2025, + author = {math-ai}, + title = {AIME Problem Set 2025}, + year = {2025}, + publisher = {Huggingface}, + url = {https://huggingface.co/datasets/math-ai/aime25} +} +``` + +### Groups, Tags, and Tasks + +#### Groups + +* `math_word_problems` + +#### Tasks + +* `aime`: `AIME 1983-2024 problems` +* `aime24`: `AIME 2024 problems` +* `aime25`: `AIME 2025 problems` + +### Checklist + +For adding novel benchmarks/datasets to the library: + +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + +If other tasks on this dataset are already supported: + +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9a6cced3adcc8f8918e55c49fbc92eeda2b7623 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime.yaml @@ -0,0 +1,29 @@ +tag: + - math_word_problems +task: aime +dataset_path: gneubig/aime-1983-2024 +# dataset_name: null +output_type: generate_until +training_split: train +fewshot_split: train +test_split: train +doc_to_text: "Question: {{Question}}\nAnswer:" +doc_to_target: "{{Answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + - "<|eot_id|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 32768 +repeats: 1 +num_fewshot: 0 +metadata: + version: 0.0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime24.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime24.yaml new file mode 100644 index 0000000000000000000000000000000000000000..714596912615b5c16d4708e21f0eb56b33959754 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime24.yaml @@ -0,0 +1,29 @@ +tag: + - math_word_problems +task: aime24 +dataset_path: Maxwell-Jia/AIME_2024 +# dataset_name: null +output_type: generate_until +training_split: train +fewshot_split: train +test_split: train +doc_to_text: "Question: {{Problem}}\nAnswer:" +doc_to_target: "{{Answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + - "<|eot_id|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 32768 +repeats: 1 +num_fewshot: 0 +metadata: + version: 0.0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime25.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime25.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ef64005863674f7afc5c76b8cdff22d224ae2da --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/aime25.yaml @@ -0,0 +1,29 @@ +tag: + - math_word_problems +task: aime25 +dataset_path: math-ai/aime25 +# dataset_name: null +output_type: generate_until +training_split: test +fewshot_split: test +test_split: test +doc_to_text: "Question: {{problem}}\nAnswer:" +doc_to_target: "{{answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + - "<|eot_id|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 32768 +repeats: 1 +num_fewshot: 0 +metadata: + version: 0.0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f668c23bc18d646c16390302ad24cc3ced1aa3b4 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/aime/utils.py @@ -0,0 +1,231 @@ +import re +from typing import Dict, List + + +def process_results(doc: dict, results: List[str]) -> Dict[str, int]: + retval = 0 + response = results[0] + + # Try to extract answer from $...$ format first + indices = [pos for pos, char in enumerate(response) if char == "$"] + if len(indices) <= 1: + answer = response + else: + answer = response[indices[0] + 1 : indices[-1]] + + # Extract from \\boxed{} if present + boxed_answer = last_boxed_only_string(response) + if boxed_answer is not None: + try: + boxed_content = remove_boxed(boxed_answer) + if boxed_content is not None: + answer = boxed_content + except (AssertionError, IndexError): + pass + + # Check if answer matches target + answer_key = next(k for k in doc.keys() if k.lower() == "answer") + target = str(doc[answer_key]) + if is_equiv(answer, target): + retval = 1 + + return {"exact_match": retval} + + +# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py +def is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + print("WARNING: Both None") + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def remove_boxed(s): + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except AssertionError: + return string + + +def remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = fix_a_slash_b(string) + + return string diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c56206923cf19bac4ec07233c6b0b17ac0460ad --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/gsm8k/gsm8k.yaml @@ -0,0 +1,15 @@ +task: gsm8k +dataset_path: openai/gsm8k +dataset_name: main +output_type: generate_until +training_split: train +fewshot_split: train +test_split: test +doc_to_text: !function utils.gsm_prompt +doc_to_target: "{{answer.split('####')[-1].strip()}}" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/gsm8k/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/gsm8k/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ceaa3d2ab7af89f27e69b470a2f6787f6133519 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/gsm8k/utils.py @@ -0,0 +1,13 @@ +def gsm_prompt(doc): + system_prompt = ( + "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n" + "Respond in the following format:\n" + "\n" + "Your reasoning here\n" + "\n" + "\n" + "\\boxed{...}\n" + "" + ) + prompt = f"{system_prompt}\n\n{doc['question']}\n\n" + return prompt diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..024d38f0da160e853cd8c3123104a4485677c0fd --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/humaneval/humaneval.yaml @@ -0,0 +1,13 @@ +task: humaneval +dataset_path: openai/openai_humaneval +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n{{prompt}}\n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n\nYour reasoning here\n\n\n```python\nThe complete implementation of the {{entry_point}} function\n```\n" +doc_to_target: "{{test}}\ncheck({{entry_point}})" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/humaneval/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/humaneval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11bac61cfa12fad57aacfed28b55bee467cf23e4 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/humaneval/utils.py @@ -0,0 +1,43 @@ +import evaluate as hf_evaluate + + +try: + compute_ = hf_evaluate.load("code_eval") + test_cases = ["assert add(2, 3)==5"] + candidates = [["def add(a,b): return a*b"]] + results = compute_.compute(references=test_cases, predictions=candidates, k=[1]) +except Exception as e: + raise e + + +def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None): + global compute_ + assert k is not None + if isinstance(k, int): + k = [k] + res = compute_.compute( + references=references, + predictions=predictions, + k=k + ) + return res[0] + + +def clean_response_string(r: str) -> str: + cleaned_text = r if r.rfind("```python") == -1 else r[r.rfind("```python"):] + cleaned_text = cleaned_text if cleaned_text.rfind("```") == -1 else cleaned_text[: cleaned_text.rfind("```")] + cleaned_text = cleaned_text if cleaned_text.rfind("if __name__ == \"__main__\":") == -1 else cleaned_text[: cleaned_text.rfind("if __name__ == \"__main__\":")] + return cleaned_text + + +def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + return [[doc["prompt"] + r for r in resp] for resp, doc in zip(resps, docs)] + + +def build_predictions( + resps: list[list[str]], docs: list[dict] +) -> list[list[str]]: + return [ + [clean_response_string(r) for r in resp] + for resp, doc in zip(resps, docs) + ] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/math500/math500.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/math500/math500.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fe2f7a38417fe863c1301953be514b618054707 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/math500/math500.yaml @@ -0,0 +1,12 @@ +task: math500 +dataset_path: HuggingFaceH4/MATH-500 +output_type: generate_until +test_split: test +doc_to_text: !function utils.math500_prompt +doc_to_target: "{{answer}}" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/math500/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/math500/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0585298c29c8b5c12ebeaa01dfff572267db601 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/math500/utils.py @@ -0,0 +1,14 @@ +def math500_prompt(doc): + system_prompt = ( + "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n" + "Respond in the following format:\n" + "\n" + "Your reasoning here\n" + "\n" + "\n" + "\\boxed{...}\n" + "" + ) + + prompt = f"{system_prompt}\n\n{doc['problem']}\n\n" + return prompt diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5b9755ad30669e2335bd374ba5f53db0572630f --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/mbpp/mbpp.yaml @@ -0,0 +1,14 @@ +task: mbpp +dataset_path: google-research-datasets/mbpp +dataset_name: full +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "\n{{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}} \n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n\nYour reasoning here\n\n\n```python\nThe complete implementation of the function\n```\n" +doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}" +target_delimiter: "" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/mbpp/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/mbpp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..234fc7ed5de047e556dea2ff77d02a232c8f3e6e --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/tasks/mbpp/utils.py @@ -0,0 +1,79 @@ +import re +from typing import Union + +import evaluate as hf_evaluate + + +try: + pass_at_k = hf_evaluate.load("code_eval") + + # run simple test to check code execution is enabled before model generation + test_cases = ["assert add(2, 3)==5"] + candidates = [["def add(a,b): return a*b"]] + results = pass_at_k.compute(references=test_cases, predictions=candidates, k=[1]) +except Exception as e: + raise e + + +def pass_at_1( + references: Union[str, list[str]], predictions: Union[str, list[list[str]]] +) -> float: + if isinstance(references, str): + references = [references] + if isinstance(predictions[0], str): + predictions = [[p] for p in predictions] + return pass_at_k.compute( + references=references, + predictions=predictions, + k=[1], + num_workers=48 + )[0]["pass@1"] + + +def extract_code_blocks(text: str) -> str: + text = re.sub(r"\[DONE\]", "", text) + text = re.sub(r"<\|eot_id\|>", "", text) + text = re.sub(r"<\|endoftext\|>", "", text) + return text + + +def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + return [[extract_code_blocks(r) for r in resp] for resp in resps] + + +def list_fewshot_samples(): + return [ + { + "task_id": 2, + "text": "Write a function to find the similar elements from the given two tuple lists.", + "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ", + "test_list": [ + "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)", + "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)", + "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)", + ], + "is_fewshot": True, + }, + { + "task_id": 3, + "text": "Write a python function to identify non-prime numbers.", + "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result", + "test_list": [ + "assert is_not_prime(2) == False", + "assert is_not_prime(10) == True", + "assert is_not_prime(35) == True", + ], + "is_fewshot": True, + }, + { + "task_id": 4, + "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.", + "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums", + "test_list": [ + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]", + ], + "is_fewshot": True, + }, + ] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d75d370a30862ba13dc3905f39031f631d70e8fd --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/dllm_eval/utils.py @@ -0,0 +1,552 @@ +import collections +import fnmatch +import functools +import hashlib +import importlib.util +import inspect +import json +import logging +import os +import re +from dataclasses import asdict, is_dataclass +from itertools import islice +from pathlib import Path +from typing import Any, Callable, Generator, List, Optional, Tuple + +import numpy as np +import yaml +from jinja2 import BaseLoader, Environment, StrictUndefined + + +SPACING = " " * 47 + +HIGHER_IS_BETTER_SYMBOLS = { + True: "↑", + False: "↓", +} + + +def setup_logging(verbosity=logging.INFO): + # Configure the root logger + class CustomFormatter(logging.Formatter): + def format(self, record): + if record.name.startswith("dllm_eval."): + record.name = record.name[len("dllm_eval.") :] + return super().format(record) + + formatter = CustomFormatter( + "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d:%H:%M:%S", + ) + + log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity + + level_map = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + + log_level = level_map.get(str(log_level).upper(), logging.INFO) + + if not logging.root.handlers: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + root_logger = logging.getLogger() + root_logger.addHandler(handler) + root_logger.setLevel(log_level) + + if log_level == logging.DEBUG: + third_party_loggers = ["urllib3", "filelock", "fsspec"] + for logger_name in third_party_loggers: + logging.getLogger(logger_name).setLevel(logging.INFO) + else: + logging.getLogger().setLevel(log_level) + + +def hash_string(string: str) -> str: + return hashlib.sha256(string.encode("utf-8")).hexdigest() + + +def escaped_split(text, sep_char, maxsplit=-1): + """Split text into a list on occurrences of the given separation + character `sep_char`. The separation character may be escaped by a + backslash to avoid splitting at that location. + + The separation character must be a string of size 1. + + If `maxsplit` is given, at most `maxsplit` splits are done (thus, + the list will have at most `maxsplit + 1` elements). If `maxsplit` + is not specified or less than 0, then there is no limit on the + number of splits (all possible splits are made). + """ + assert len(sep_char) == 1, ( + "separation string must be a single character for escaped splitting" + ) + + if maxsplit == 0: + return text + maxsplit = max(0, maxsplit) + + return re.split(r"(? dict: + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = { + kv[0]: handle_arg_string("=".join(kv[1:])) + for kv in [arg.split("=") for arg in arg_list] + } + return args_dict + + +def join_iters(iters): + for iter in iters: + yield from iter + + +def group(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + + return list(res.values()) + + +# Returns a list containing all values of the source_list that +# match at least one of the patterns +def pattern_match(patterns, source_list): + if isinstance(patterns, str): + patterns = [patterns] + + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return sorted(list(task_names)) + + +def softmax(x) -> np.ndarray: + """Compute softmax values for each sets of scores in x.""" + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + +def general_detokenize(string) -> str: + string = string.replace(" n't", "n't") + string = string.replace(" )", ")") + string = string.replace("( ", "(") + string = string.replace('" ', '"') + string = string.replace(' "', '"') + string = re.sub(r" (['.,])", r"\1", string) + return string + + +def get_file_task_name(filename: str) -> str: + """ + Given the sample results filenames, extracts and returns the task name. + """ + return filename[filename.find("_") + 1 : filename.rfind("_")] + + +def get_file_datetime(filename: str) -> str: + """ + Given the results and sample results filenames, extracts and returns the datetime. + """ + return filename[filename.rfind("_") + 1 :].replace(".jsonl", "") + + +def sanitize_model_name(model_name: str) -> str: + """ + Given the model name, returns a sanitized version of it. + """ + return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name) + + +def sanitize_task_name(task_name: str) -> str: + """ + Given the task name, returns a sanitized version of it. + """ + return re.sub(r"\W", "_", task_name) + + +def get_latest_filename(filenames: List[str]) -> str: + """ + Given a list of filenames, returns the filename with the latest datetime. + """ + return max(filenames, key=lambda f: get_file_datetime(f)) + + +def get_results_filenames(filenames: List[str]) -> List[str]: + """ + Extracts filenames that correspond to aggregated results. + """ + return [f for f in filenames if "/results_" in f and ".json" in f] + + +def get_sample_results_filenames(filenames: List[str]) -> List[str]: + """ + Extracts filenames that correspond to sample results. + """ + return [f for f in filenames if "/samples_" in f and ".json" in f] + + +def get_rolling_token_windows( + token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int +) -> Generator[Tuple[List[int], List[int]], None, None]: + """ + - context_len allows for a rolling window context, allowing each prediction window to potentially + condition on some context + + :param token_list: list + List of tokens to be PREDICTED + :param max_seq_len: int + max_seq_len of model (or max_seq_len we want to use) + :param context_len: int + Amount of desired token context for prediction. Needs to be at least 1. + :param prefix_token: token + Dummy token like so the first token has something to condition on + :return: generator + Generator of tuples + (input_tokens, pred_tokens) + Note: Score only the last len(pred_tokens) logits of the LM + """ + assert 1 <= context_len <= max_seq_len + if not token_list: + return + # +1 offset, going from input->preds + pred_len = max_seq_len - context_len + 1 + predicted = 0 + + # Special handling for first window: predict all tokens + first_seq_len = min(max_seq_len, len(token_list)) + yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len] + predicted += first_seq_len + + while predicted < len(token_list): + window_pred_len = min(len(token_list) - predicted, pred_len) + window_end = predicted + window_pred_len + + yield ( + token_list[window_end - max_seq_len - 1 : window_end - 1], + token_list[window_end - window_pred_len : window_end], + ) + predicted += window_pred_len + + +def make_disjoint_window( + pair: Tuple[List[int], List[int]], +) -> Tuple[List[int], List[int]]: + """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation""" + a, b = pair + return a[: len(a) - (len(b) - 1)], b + + +class EnhancedJSONEncoder(json.JSONEncoder): + """ + Provides a proper json encoding for the loggers and trackers json dumps. + Notably manages the json encoding of dataclasses. + """ + + def default(self, o): + if is_dataclass(o): + return asdict(o) + return super().default(o) + + +class Reorderer: + def __init__(self, arr: List[Any], fn: Callable) -> None: + """Reorder an array according to some function + + Args: + arr (List[Any]): The initial array + fn (Callable[[Any], Any]): A function to determine the priority of elements + """ + self.size = len(arr) + arr = list(enumerate(arr)) + arr = group(arr, lambda x: fn(x[1])) + # arr = [([y[0] for y in x], x[0][1]) for x in arr] + # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this + arr = [([y[0]], x[0][1]) for x in arr for y in x] + arr.sort(key=lambda x: fn(x[1])) + + self.arr = arr + + def get_reordered(self): + """Gets the reordered array + + Returns: + List[Any]: The reordered array + """ + return [x[1] for x in self.arr] + + def get_original(self, newarr): + """Restores the original order of a new array based on the old array's order + + Args: + newarr (List[Any]): The array to be restored + + Returns: + List[Any]: The array restored to the original order + """ + res = [None] * self.size + cov = [False] * self.size + + for (inds, _), v in zip(self.arr, newarr): + for ind in inds: + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + +def make_table(result_dict, column: str = "results", sort_results: bool = False): + """Generate table of results.""" + from pytablewriter import LatexTableWriter, MarkdownTableWriter + + if column == "results": + column_name = "Tasks" + elif column == "groups": + column_name = "Groups" + + all_headers = [ + column_name, + "Version", + "Filter", + "n-shot", + "Metric", + "", + "Value", + "", + "Stderr", + ] + + md_writer = MarkdownTableWriter() + latex_writer = LatexTableWriter() + md_writer.headers = all_headers + latex_writer.headers = all_headers + + values = [] + + keys = result_dict[column].keys() + if sort_results: + # sort entries alphabetically by task or group name. + # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu. + # sorting here would mess that up + keys = sorted(keys) + for k in keys: + dic = result_dict[column][k] + version = result_dict["versions"].get(k, " N/A") + n = str(result_dict.get("n-shot", " ").get(k, " ")) + higher_is_better = result_dict.get("higher_is_better", {}).get(k, {}) + + if "alias" in dic: + k = dic.pop("alias") + + metric_items = dic.items() + metric_items = sorted(metric_items) + + for (mf), v in metric_items: + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + + hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "") + + v = "%.4f" % v if isinstance(v, float) else v + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + se = " N/A" if se == "N/A" else "%.4f" % se + values.append([k, version, f, n, m, hib, v, "±", se]) + else: + values.append([k, version, f, n, m, hib, v, "", ""]) + k = "" + version = "" + md_writer.value_matrix = values + latex_writer.value_matrix = values + + # todo: make latex table look good + # print(latex_writer.dumps()) + + return md_writer.dumps() + + +def positional_deprecated(fn): + """ + A decorator to nudge users into passing only keyword args (`kwargs`) to the + wrapped function, `fn`. + """ + + @functools.wraps(fn) + def _wrapper(*args, **kwargs): + if len(args) != 1 if inspect.ismethod(fn) else 0: + print( + f"WARNING: using {fn.__name__} with positional arguments is " + "deprecated and will be disallowed in a future version of " + "lm-evaluation-harness!" + ) + return fn(*args, **kwargs) + + return _wrapper + + +def ignore_constructor(loader, node): + return node + + +def import_function(loader: yaml.Loader, node, yaml_path: Path): + function_name = loader.construct_scalar(node) + + *module_name, function_name = function_name.split(".") + if isinstance(module_name, list): + module_name = ".".join(module_name) + module_path = yaml_path.parent / f"{module_name}.py" + + spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix()) + + if spec is None: + raise ImportError(f"Could not import module {module_name} from {module_path}.") + module = importlib.util.module_from_spec(spec) + + if spec.loader is None: + raise ImportError(f"Module loader is None, {module_name} from {module_path}.") + spec.loader.exec_module(module) + + function = getattr(module, function_name) + return function + + +def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"): + if mode == "simple": + constructor_fn = ignore_constructor + elif mode == "full": + if yaml_path is None: + raise ValueError("yaml_path must be provided if mode is 'full'.") + # Attach yaml_path to the import function so that it can be used later + constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path)) + + loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader + # Add the import_function constructor to the YAML loader + yaml.add_constructor("!function", constructor_fn, Loader=loader) + if yaml_config is None: + with open(yaml_path, "rb") as file: + yaml_config = yaml.load(file, Loader=loader) + + if yaml_dir is None: + yaml_dir = os.path.dirname(yaml_path) + + assert yaml_dir is not None + + if "include" in yaml_config: + include_path = yaml_config["include"] + del yaml_config["include"] + + if isinstance(include_path, str): + include_path = [include_path] + + # Load from the last one first + include_path.reverse() + final_yaml_config = {} + for path in include_path: + # Assumes that path is a full path. + # If not found, assume the included yaml + # is in the same dir as the original yaml + if not os.path.isfile(path): + path = os.path.join(yaml_dir, path) + + try: + included_yaml_config = load_yaml_config(yaml_path=path, mode=mode) + final_yaml_config.update(included_yaml_config) + except Exception as ex: + # If failed to load, ignore + raise ex + + final_yaml_config.update(yaml_config) + return final_yaml_config + return yaml_config + + +def regex_replace(string, pattern, repl, count: int = 0): + """Implements the `re.sub` function as a custom Jinja filter.""" + return re.sub(pattern, repl, string, count=count) + + +env = Environment( + loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True +) +env.filters["regex_replace"] = regex_replace + + +def apply_template(template: str, doc: dict) -> str: + rtemplate = env.from_string(template) + return rtemplate.render(**doc) + + +def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None): + """ + Method for creating a (potentially) sliced and limited + iterator from a raw document iterator. Used for splitting data + among ranks in multigpu setting or only pulling a sample of documents + """ + return islice(raw_iterator, rank, limit, world_size) + + +def weighted_f1_score(items): + from sklearn.metrics import f1_score + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + fscore = f1_score(golds, preds, average="weighted") + return fscore diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/evaluation_script.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/evaluation_script.py new file mode 100644 index 0000000000000000000000000000000000000000..0c90bd0c9c7ebd1f15b77158670a3858f5532468 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/evaluation_script.py @@ -0,0 +1,21 @@ +import os +import torch +import random +import numpy as np +from dllm_eval.__main__ import cli_evaluate + + +def set_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +if __name__ == "__main__": + os.environ["HF_ALLOW_CODE_EVAL"] = "1" + os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" + set_seed(42) + cli_evaluate() \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/gsm8k_all.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/gsm8k_all.py new file mode 100644 index 0000000000000000000000000000000000000000..7133a935166c211bf8a8f2e535ae7e1bd54061b6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/gsm8k_all.py @@ -0,0 +1,286 @@ +import json +import re +import os +import math +import argparse +from collections import Counter + +RES_PATH = "" + +def last_boxed_only_string(string): + if not string: return None + idx = max(string.rfind("\\boxed"), string.rfind("\\fbox")) + if idx < 0: return None + + if "\\boxed " in string[idx:idx+8] and "{" not in string[idx:idx+8]: + return "\\boxed " + string[idx:].split("\\boxed ")[-1].split("$")[0].strip() + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + elif string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + return string[idx : right_brace_idx + 1] if right_brace_idx else None + +def remove_boxed(s): + if not s: return None + if "\\boxed " in s: return s[len("\\boxed ") :] + if "\\boxed{" in s and s.endswith("}"): return s[len("\\boxed{") : -1] + if "\\fbox{" in s and s.endswith("}"): return s[len("\\fbox{") : -1] + return s + +def strip_string(string): + if string is None: return "" + string = str(string).strip() + while re.search(r"(\d),(\d{3})", string): + string = re.sub(r"(\d),(\d{3})", r"\1\2", string) + + string = string.replace("\n", "").replace("\\!", "") + string = string.replace("tfrac", "frac").replace("dfrac", "frac") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("^{\\circ}", "").replace("^\\circ", "") + string = string.replace("\\$", "").replace("\\%", "").replace("\%", "") + + if "=" in string and len(string.split("=")[0]) <= 5: + string = string.split("=")[1].strip() + + string = string.replace(" ", "") + string = string.rstrip(".") + return string + +def normalize_to_number(s): + s_clean = strip_string(s) + try: + if '/' in s_clean and len(s_clean.split('/')) == 2: + parts = s_clean.split('/') + return float(parts[0]) / float(parts[1]) + return float(s_clean) + except: + return s_clean + +def extract_answer_gsm8k_debug(text): + if not text: return "", "empty" + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip() + + boxed = last_boxed_only_string(text) + if boxed: + ans = remove_boxed(boxed) + if ans: + return strip_string(ans), "boxed" + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + return strip_string(tag_match.group(1)), "xml_tag" + + last_text = text[-200:] if len(text) > 200 else text + marker = "the answer is" + if marker in last_text.lower(): + idx = last_text.lower().rfind(marker) + after = last_text[idx + len(marker):].strip() + after = re.split(r"[.\n]", after)[0] + after = after.replace(":", "").replace("$", "").strip() + return strip_string(after), "text_marker" + + tail = text[-50:] + nums = re.findall(r"(?>> 正在评测: {file_path}") + detailed_results = [] + + correct_voted_count = 0 + correct_any_count = 0 + total_count = 0 + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + try: + item = json.loads(line) + except: + continue + + doc = item.get("doc", {}) + ground_truth = extract_gold_gsm8k(str(item.get("target", ""))) + + total_nfe_item = item.get("nfe", 0) + nfe_list.append(total_nfe_item) + svf_list.append(item.get("svf_calls", 0)) + + trajectories = item.get("all_trajectories", []) + if not trajectories: + resps = item.get("resps", []) + for r in resps: + text = r[0] if isinstance(r, list) else r + trajectories.append({"resp": text, "score": 0.0}) + + parsed_paths = [] + traj_debug_info = [] + + for idx, traj in enumerate(trajectories): + raw_text = traj.get("resp", "") + score = traj.get("score", 0.0) + + extracted, method = extract_answer_gsm8k_debug(raw_text) + + is_correct_single = False + if extracted: + is_correct_single = is_equiv(extracted, ground_truth) + val_key = normalize_to_number(extracted) + + parsed_paths.append({ + "original_text": extracted, + "val_key": val_key, + "score": score, + "method": method + }) + + traj_debug_info.append({ + "id": idx, + "extracted": extracted, + "score": score, + "is_correct": is_correct_single, + "extract_method": method + }) + + if not parsed_paths: + detailed_results.append({ + "question": doc.get("question", "N/A"), + "final_voted_answer": "", + "ground_truth": ground_truth, + "is_voted_correct": False, + "trajectory_details": traj_debug_info, + "nfe": total_nfe_item, + "svf_calls": item.get("svf_calls", 0) + }) + total_count += 1 + continue + + has_correct = any(p['score'] > -999 and is_equiv(p['original_text'], ground_truth) for p in parsed_paths) + if has_correct: + correct_any_count += 1 + + parsed_paths.sort(key=lambda x: x['score'], reverse=True) + top_k_count = max(1, int(len(parsed_paths) * 0.6)) + voting_candidates = parsed_paths[:top_k_count] + + ans_stats = {} + for p in voting_candidates: + k = p['val_key'] + if k not in ans_stats: + ans_stats[k] = { + "total_weight": 0.0, + "count": 0, + "max_score": -float('inf'), + "best_repr": p['original_text'] + } + + try: + weight = math.exp(p['score']) + except OverflowError: + weight = float('inf') + + ans_stats[k]["total_weight"] += weight + ans_stats[k]["count"] += 1 + if p['score'] > ans_stats[k]["max_score"]: + ans_stats[k]["max_score"] = p['score'] + ans_stats[k]["best_repr"] = p['original_text'] + + sorted_answers = sorted( + ans_stats.items(), + key=lambda x: (x[1]["total_weight"], x[1]["max_score"]), + reverse=True + ) + + best_pred = str(sorted_answers[0][1]["best_repr"]) + is_voted_correct = is_equiv(best_pred, ground_truth) + if is_voted_correct: + correct_voted_count += 1 + + vote_summary = [] + for val, info in sorted_answers: + vote_summary.append({ + "answer": str(val), + "count": info["count"], + "total_weight": info["total_weight"], + "is_correct": is_equiv(str(val), ground_truth) + }) + + total_count += 1 + + detailed_results.append({ + "question": doc.get("question", "N/A"), + "final_voted_answer": best_pred, + "ground_truth": ground_truth, + "is_voted_correct": is_voted_correct, + "vote_stats": vote_summary, + "trajectory_details": traj_debug_info, + "nfe": total_nfe_item, + "svf_calls": item.get("svf_calls", 0) + }) + + accuracy = (correct_voted_count / total_count * 100) if total_count > 0 else 0 + pass_at_k = (correct_any_count / total_count * 100) if total_count > 0 else 0 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "summary": { + "accuracy": f"{accuracy:.2f}%", + "correct_voted": correct_voted_count, + "total": total_count, + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/humaneval_all.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/humaneval_all.py new file mode 100644 index 0000000000000000000000000000000000000000..842a77c8938d7de95247e6f153e42b00625dea99 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/humaneval_all.py @@ -0,0 +1,183 @@ +import os +import sys +import json +import ast +import traceback +import glob +import math +import argparse +from typing import Dict, List, Optional, Set, Tuple +from collections import Counter +import evaluate as hf_evaluate +import re + +RES_PATH = "" + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" + +def extract_python_code(text: str) -> str: + if not text: return "" + + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "") + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + text = tag_match.group(1) + + if "```python" in text: + content = text.split("```python")[-1] + if "```" in content: + return content.split("```")[0].strip() + return content.strip() + elif "```" in text: + content = text.split("```")[-1] + if "```" in content: + return content.split("```")[0].strip() + return content.strip() + + lines = text.split('\n') + cleaned_lines = [] + stop_words = ["Explanation:", "Example:", "Test Case:", "Output:"] + for line in lines: + if any(sw in line for sw in stop_words): + break + cleaned_lines.append(line) + + return "\n".join(cleaned_lines).strip() + +def normalize_code_for_voting(code: str) -> str: + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (node.body and isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)): + node.body.pop(0) + return ast.unparse(tree).strip() + except: + return re.sub(r"\s+", "", code) + +def sanitize(prompt: str, completion: str, entrypoint: str) -> str: + if f"def {entrypoint}" in completion: + return completion + return prompt + "\n" + completion + +def run_evaluation(target_path): + if os.path.isdir(target_path): + jsonl_files = glob.glob(os.path.join(target_path, "**/*.jsonl"), recursive=True) + else: + jsonl_files = [target_path] + + if not jsonl_files: + print(f"未在路径 {target_path} 下找到任何 .jsonl 文件") + return + + print(f"共找到 {len(jsonl_files)} 个评测任务") + code_eval = hf_evaluate.load("code_eval") + + for file_path in jsonl_files: + print(f"\n>>> 正在评测: {file_path}") + all_predictions = [] + all_references = [] + detailed_results = [] + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + if not lines: continue + + for line in lines: + if not line.strip(): continue + item = json.loads(line) + doc = item.get("doc", {}) + prompt = doc.get("prompt", "") + entry_point = doc.get("entry_point", "") + reference = doc.get("test", "") + + current_nfe = item.get("nfe", 0) + nfe_list.append(current_nfe) + svf_list.append(item.get("svf_calls", 0)) + + resps = item.get("resps", []) + candidate_stats = {} + + for r in resps: + raw_text = r[0] if isinstance(r, list) else r + completion = extract_python_code(raw_text) + full_code = sanitize(prompt, completion, entry_point) + + try: + ast.parse(full_code) + is_valid = True + except: + is_valid = False + + logic_norm = normalize_code_for_voting(full_code) + if not logic_norm: continue + + if logic_norm not in candidate_stats: + candidate_stats[logic_norm] = {"count": 0, "valid": is_valid, "code": full_code} + candidate_stats[logic_norm]["count"] += 1 + + if not candidate_stats: + voted_code = prompt + else: + sorted_logics = sorted( + candidate_stats.keys(), + key=lambda k: (candidate_stats[k]["valid"], candidate_stats[k]["count"]), + reverse=True + ) + voted_code = candidate_stats[sorted_logics[0]]["code"] + + all_predictions.append([voted_code]) + all_references.append(reference) + detailed_results.append({ + "task_id": doc.get("task_id", doc.get("name", "N/A")), + "voted_code": voted_code, + "nfe": current_nfe, + "svf_calls": item.get("svf_calls", 0), + "candidates_count": len(candidate_stats) + }) + + if not all_predictions: continue + + print(f"正在执行代码测试 (共 {len(all_predictions)} 题)...") + pass_at_k, exec_results = code_eval.compute( + references=all_references, + predictions=all_predictions, + k=[1], + num_workers=4 + ) + + accuracy = pass_at_k.get("pass@1", 0.0) * 100 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- 结果: Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + for i, detail in enumerate(detailed_results): + res_list = exec_results.get(i, []) + detail["is_correct"] = res_list[0][1]["passed"] if res_list else False + + final_report = { + "summary": { + "accuracy": f"{accuracy:.2f}%", + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + print(f"报告已保存至: {output_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/math500_all.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/math500_all.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7d8671623f3e848ebca1c5836928185179f5fb --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/math500_all.py @@ -0,0 +1,213 @@ +import json +import re +import os +import math +import argparse +from collections import Counter + +RES_PATH = "" + +def extract_answer(text): + if not text: + return "", False + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip() + + boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}" + all_boxes = re.findall(boxed_pattern, text) + if all_boxes: + return all_boxes[-1], True + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + return tag_match.group(1).strip(), True + + marker = "the answer is" + if marker in text.lower(): + pos = text.lower().rfind(marker) + after_text = text[pos + len(marker):].strip() + after_text = re.sub(r"^[:\s]+", "", after_text) + return after_text.split('\n')[0].split('$')[0].strip(), True + + tail = text[-50:].strip() + nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail) + if nums: + return nums[-1], False + return "", False + +def normalize_math(string): + if not string: return "" + string = str(string).lower().strip() + + string = string.replace("", "").replace("", "").replace("", "") + string = string.replace("...", "").replace("cannot be determined", "") + + string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string) + string = re.sub(r"\\text\{([^}]*)\}", r"\1", string) + string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname|mathrm)\{([^}]*)\}", r"\2", string) + string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string) + string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "") + + units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)" + string = re.sub(units_pattern, "", string) + string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "") + string = string.replace("\\pi", "pi") + string = re.sub(r"(\d),(\d{3})", r"\1\2", string) + string = string.rstrip(".:,; ").replace(" ", "") + + if "=" in string: + string = string.split("=")[-1] + + return string + +def is_equiv(pred, gold): + if not pred: return False + p, g = normalize_math(pred), normalize_math(gold) + if p == g: return True + + if "=" in pred: + if normalize_math(pred.split("=")[-1]) == g: + return True + + try: + def to_float(s): + if '/' in s and s.count('/') == 1: + parts = s.split('/') + return float(parts[0]) / float(parts[1]) + if '_' in s: s = s.split('_')[0] + return float(s) + return math.isclose(to_float(p), to_float(g), rel_tol=1e-4) + except: + p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p) + g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g) + return p_fuzzy == g_fuzzy if p_fuzzy else False + +def run_evaluation(target_path): + jsonl_files = [] + if os.path.isdir(target_path): + for root, dirs, files in os.walk(target_path): + for file in files: + if file.endswith(".jsonl") and not file.startswith("eval_voted_"): + jsonl_files.append(os.path.join(root, file)) + else: + jsonl_files = [target_path] + + for file_path in jsonl_files: + print(f">>> 正在评测: {file_path}") + detailed_results = [] + + voted_correct_count = 0 + pass_at_k_count = 0 + total_count = 0 + + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + try: + item = json.loads(line) + except: + continue + + doc = item.get("doc", {}) + ground_truth = str(item.get("target", doc.get("answer", ""))) + + current_nfe = item.get("nfe", 0) + nfe_list.append(current_nfe) + current_svf = item.get("svf_calls", 0) + svf_list.append(current_svf) + + ans_stats = {} + trajectories = item.get("all_trajectories", []) + + has_correct_trajectory = False + + for traj in trajectories: + raw_text = traj.get("resp", "") + score = traj.get("score", 0) + + extracted, _ = extract_answer(raw_text) + if not extracted: continue + + if is_equiv(extracted, ground_truth): + has_correct_trajectory = True + + norm = normalize_math(extracted) + if norm not in ans_stats: + ans_stats[norm] = { + "count": 0, + "max_score": -float('inf'), + "total_weight": 0.0, + "original": extracted + } + + ans_stats[norm]["count"] += 1 + if score > ans_stats[norm]["max_score"]: + ans_stats[norm]["max_score"] = score + + try: + weight = math.exp(score) + except OverflowError: + weight = float('inf') + ans_stats[norm]["total_weight"] += weight + + if has_correct_trajectory: + pass_at_k_count += 1 + + if not ans_stats: + best_pred = "" + else: + sorted_norms = sorted( + ans_stats.keys(), + key=lambda x: (ans_stats[x]["total_weight"], ans_stats[x]["max_score"], ans_stats[x]["count"]), + reverse=True + ) + best_norm = sorted_norms[0] + best_pred = ans_stats[best_norm]["original"] + + is_voted_correct = False + if best_pred and is_equiv(best_pred, ground_truth): + voted_correct_count += 1 + is_voted_correct = True + + total_count += 1 + + detailed_results.append({ + "question": doc.get("problem", "N/A"), + "final_voted_answer": best_pred, + "ground_truth": ground_truth, + "is_voted_correct": is_voted_correct, + "nfe": current_nfe, + "svf_calls": current_svf + }) + + pass_at_1_accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- Accuracy: {pass_at_1_accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "summary": { + "accuracy": f"{pass_at_1_accuracy:.2f}%", + "correct_voted_count": voted_correct_count, + "total": total_count, + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/mbpp_all.py b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/mbpp_all.py new file mode 100644 index 0000000000000000000000000000000000000000..7dce200195e0e46d44c74008ade6d22492fc3267 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/metrics/mbpp_all.py @@ -0,0 +1,194 @@ +import os +import json +import ast +import glob +import re +import argparse +from typing import Dict, List, Optional, Set, Tuple +import evaluate as hf_evaluate + +RES_PATH = "" + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +def extract_python_code(text: str) -> str: + if not text: return "" + + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "") + + tag_matches = re.findall(r"(.*?)", text, re.DOTALL) + if tag_matches: + for block in tag_matches: + if "def " in block: + text = block + break + else: + text = tag_matches[0] + + if "```python" in text: + blocks = text.split("```python") + for b in blocks[1:]: + code = b.split("```")[0].strip() + if "def " in code: return code + elif "```" in text: + blocks = text.split("```") + for b in blocks[1:]: + code = b.strip() + if "def " in code: return code + + lines = text.split('\n') + cleaned_lines = [] + stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Reasoning:"] + for line in lines: + if any(sw in line for sw in stop_words): break + cleaned_lines.append(line) + + return "\n".join(cleaned_lines).strip() + +def normalize_code_for_voting(code: str) -> str: + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (node.body and isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)): + node.body.pop(0) + return ast.unparse(tree).strip() + except: + return re.sub(r"\s+", "", code) + +def run_evaluation(target_path): + target_path = os.path.abspath(target_path) + + if os.path.isdir(target_path): + search_pattern = os.path.join(target_path, "**/*.jsonl") + jsonl_files = glob.glob(search_pattern, recursive=True) + jsonl_files = [f for f in jsonl_files if not os.path.basename(f).startswith("eval_mbpp_")] + else: + jsonl_files = [target_path] + + if not jsonl_files: + print(f"Error: 在路径 {target_path} 及其子目录下未找到任何 .jsonl 文件。") + return + + try: + code_eval = hf_evaluate.load("code_eval") + except: + print("Error: Could not load code_eval. Ensure 'evaluate' and 'code_eval' are installed.") + return + + for file_path in jsonl_files: + print(f"\n>>> 正在评测 MBPP 文件: {file_path}") + all_candidate_predictions = [] + all_voted_predictions = [] + all_references = [] + detailed_results = [] + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + item = json.loads(line) + + doc = item.get("doc", {}) + test_list = doc.get("test_list", []) + test_setup = doc.get("test_setup_code", "") + full_reference = (test_setup + "\n" + "\n".join(test_list)).strip() + + item_nfe = item.get("nfe", 0) + item_svf = item.get("svf_calls", 0) + nfe_list.append(item_nfe) + svf_list.append(item_svf) + + resps = item.get("resps", []) + trajs = item.get("all_trajectories", []) + + candidate_stats = {} + processed_candidates = [] + + source_data = trajs if trajs else resps + for idx, entry in enumerate(source_data): + raw_text = entry.get("resp", "") if isinstance(entry, dict) else (entry[0] if isinstance(entry, list) else entry) + score = entry.get("score", 0) if isinstance(entry, dict) else 0 + + code = extract_python_code(raw_text) + if not code: continue + + processed_candidates.append(code) + + try: + ast.parse(code) + is_valid = True + except: + is_valid = False + + norm = normalize_code_for_voting(code) + if norm not in candidate_stats: + candidate_stats[norm] = {"count": 0, "valid": is_valid, "code": code, "max_score": -float('inf')} + candidate_stats[norm]["count"] += 1 + candidate_stats[norm]["max_score"] = max(candidate_stats[norm]["max_score"], score) + + if not candidate_stats: + voted_code = "" + else: + sorted_norms = sorted( + candidate_stats.keys(), + key=lambda k: (candidate_stats[k]["valid"], candidate_stats[k]["max_score"], candidate_stats[k]["count"]), + reverse=True + ) + voted_code = candidate_stats[sorted_norms[0]]["code"] + + all_candidate_predictions.append(processed_candidates if processed_candidates else [""]) + all_voted_predictions.append([voted_code]) + all_references.append(full_reference) + + detailed_results.append({ + "task_id": doc.get("task_id", "N/A"), + "voted_code": voted_code, + "nfe": item_nfe, + "svf_calls": item_svf, + "candidates_count": len(processed_candidates) + }) + + if not all_voted_predictions: + continue + + print(f"正在测试代码 (共 {len(all_voted_predictions)} 题)...") + res_voted, details_voted = code_eval.compute(references=all_references, predictions=all_voted_predictions, k=[1]) + res_pk, details_pk = code_eval.compute(references=all_references, predictions=all_candidate_predictions, k=[1]) + + acc_voted = res_voted.get("pass@1", 0.0) * 100 + acc_pk = res_pk.get("pass@1", 0.0) * 100 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- Pass@1: {acc_voted:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + for i, detail in enumerate(detailed_results): + detail["is_voted_correct"] = details_voted.get(i, [[0, {"passed": False}]])[0][1]["passed"] + + file_dir = os.path.dirname(file_path) + base_name = os.path.basename(file_path) + output_name = f"eval_mbpp_{base_name.replace('.jsonl', '.json')}" + output_path = os.path.join(file_dir, output_name) + + final_report = { + "summary": { + "pass_at_1": f"{acc_voted:.2f}%", + "avg_nfe": avg_nfe, + "avg_svf": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + print(f"成功保存结果至: {output_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/requirements.txt b/Prism/LLaDA2mini/LLaDA2mini_Baseline/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ae6174b97bc14baecfc1f884ce4881f62558633 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/requirements.txt @@ -0,0 +1,9 @@ +sacrebleu +evaluate +datasets +numpy +pandas +tqdm +regex +sqlitedict +pytablewriter \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_gsm8k.sh b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_gsm8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..57bb4076f47036e4c853202e998e26dac3a45cd5 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_gsm8k.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/baseline_gsm8k" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=256 +STEPS=32 +BLOCK=32 +TASK="gsm8k" +TYPE="math" +NAME="baseline_n1" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=${TYPE},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \ + --num_fewshot 0 \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_humaneval.sh b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_humaneval.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab930d729744e2bfb7b34ee7cdeaf430c6ff02aa --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_humaneval.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/baseline_humaneval" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=512 +STEPS=32 +BLOCK=32 +TASK="humaneval" +TYPE="code" +NAME="baseline_n1" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=${TYPE},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_math500.sh b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_math500.sh new file mode 100644 index 0000000000000000000000000000000000000000..e9cd72d244ab193707d84b7bda76df59e24ab05d --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_math500.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/baseline_math500" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=256 +STEPS=32 +BLOCK=32 +TASK="math500" +TYPE="math" +NAME="baseline_n1" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=${TYPE},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \ + --num_fewshot 0 \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_mbpp.sh b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_mbpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..9759c70b264fc07c503f71e2fa3cd54cefe2993d --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Baseline/scripts/run_mbpp.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/baseline_mbpp" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=512 +STEPS=32 +BLOCK=32 +TASK="mbpp" +TYPE="code" +NAME="baseline_n1" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=1,hts_mode=False,steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},task_type=${TYPE},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/baseline.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/.gitignore b/Prism/LLaDA2mini/LLaDA2mini_Prism/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..06fcf0c6ecee82cc5fe808575c4af69b9527fdb6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/.gitignore @@ -0,0 +1,210 @@ +*.jsonl +*.json + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/LICENSE b/Prism/LLaDA2mini/LLaDA2mini_Prism/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0f363b42d00f2a291c617c43e6fc3a9f142729be --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 preordinary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c50ad3edd2cb1dea76048d416624a7c7db7c3209 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/__init__.py @@ -0,0 +1,7 @@ +import logging +import os + +from .evaluator import evaluate, simple_evaluate + + +__version__ = "0.4.9" diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/__main__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9d7ccf4a174c7c41fab5148a8851353a7e7eb6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/__main__.py @@ -0,0 +1,527 @@ +import argparse +import json +import logging +import os +import sys +from functools import partial +from pathlib import Path +from typing import Union + +from dllm_eval import evaluator, utils +from dllm_eval.evaluator import request_caching_arg_to_dict +from dllm_eval.loggers import EvaluationTracker, WandbLogger +from dllm_eval.tasks import TaskManager +from dllm_eval.utils import ( + handle_non_serializable, + make_table, + simple_parse_args_string, +) + + +def try_parse_json(value: str) -> Union[str, dict, None]: + if value is None: + return None + try: + return json.loads(value) + except json.JSONDecodeError: + if "{" in value: + raise argparse.ArgumentTypeError( + f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings." + ) + return value + + +def _int_or_none_list_arg_type( + min_len: int, max_len: int, defaults: str, value: str, split_char: str = "," +): + def parse_value(item): + item = item.strip().lower() + if item == "none": + return None + try: + return int(item) + except ValueError: + raise argparse.ArgumentTypeError(f"{item} is not an integer or None") + + items = [parse_value(v) for v in value.split(split_char)] + num_items = len(items) + + if num_items == 1: + # Makes downstream handling the same for single and multiple values + items = items * max_len + elif num_items < min_len or num_items > max_len: + raise argparse.ArgumentTypeError( + f"Argument requires {max_len} integers or None, separated by '{split_char}'" + ) + elif num_items != max_len: + logging.warning( + f"Argument requires {max_len} integers or None, separated by '{split_char}'. " + "Missing values will be filled with defaults." + ) + default_items = [parse_value(v) for v in defaults.split(split_char)] + items.extend( + default_items[num_items:] + ) # extend items list with missing defaults + + return items + + +def check_argument_types(parser: argparse.ArgumentParser): + """ + Check to make sure all CLI args are typed, raises error if not + """ + for action in parser._actions: + if action.dest != "help" and not action.const: + if action.type is None: + raise ValueError( + f"Argument '{action.dest}' doesn't have a type specified." + ) + else: + continue + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument( + "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`" + ) + parser.add_argument( + "--tasks", + "-t", + default=None, + type=str, + metavar="task1,task2", + help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above", + ) + parser.add_argument( + "--model_args", + "-a", + default="", + type=try_parse_json, + help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""", + ) + parser.add_argument( + "--num_fewshot", + "-f", + type=int, + default=None, + metavar="N", + help="Number of examples in few-shot context", + ) + parser.add_argument( + "--batch_size", + "-b", + type=str, + default=1, + metavar="auto|auto:N|N", + help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.", + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=None, + metavar="N", + help="Maximal batch size to try with --batch_size auto.", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to use (e.g. cuda, cuda:0, cpu).", + ) + parser.add_argument( + "--output_path", + "-o", + default=None, + type=str, + metavar="DIR|DIR/file.json", + help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.", + ) + parser.add_argument( + "--limit", + "-L", + type=float, + default=None, + metavar="N|0 argparse.Namespace: + check_argument_types(parser) + return parser.parse_args() + + +def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: + if not args: + # we allow for args to be passed externally, else we parse them ourselves + parser = setup_parser() + args = parse_eval_args(parser) + + if args.wandb_args: + wandb_args_dict = simple_parse_args_string(args.wandb_args) + wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args) + wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict) + + utils.setup_logging(args.verbosity) + eval_logger = logging.getLogger(__name__) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # update the evaluation tracker args with the output path and the HF token + if args.output_path: + args.hf_hub_log_args += f",output_path={args.output_path}" + if os.environ.get("HF_TOKEN", None): + args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}" + evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args) + evaluation_tracker = EvaluationTracker(**evaluation_tracker_args) + + if args.predict_only: + args.log_samples = True + if (args.log_samples or args.predict_only) and not args.output_path: + raise ValueError( + "Specify --output_path if providing --log_samples or --predict_only" + ) + + if args.fewshot_as_multiturn and args.apply_chat_template is False: + raise ValueError( + "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)." + ) + + if args.include_path is not None: + eval_logger.info(f"Including path: {args.include_path}") + metadata = ( + simple_parse_args_string(args.model_args) + if isinstance(args.model_args, str) + else args.model_args + if isinstance(args.model_args, dict) + else {} + ) | ( + args.metadata + if isinstance(args.metadata, dict) + else simple_parse_args_string(args.metadata) + ) + + task_manager = TaskManager(include_path=args.include_path, metadata=metadata) + + if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples: + eval_logger.warning( + "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub." + ) + + if args.limit: + eval_logger.warning( + " --limit SHOULD ONLY BE USED FOR TESTING." + "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." + ) + if args.samples: + assert args.limit is None, ( + "If --samples is not None, then --limit must be None." + ) + if (samples := Path(args.samples)).is_file(): + args.samples = json.loads(samples.read_text()) + else: + args.samples = json.loads(args.samples) + + if args.tasks is None: + eval_logger.error("Need to specify task to evaluate.") + sys.exit() + elif args.tasks == "list": + print(task_manager.list_all_tasks()) + sys.exit() + elif args.tasks == "list_groups": + print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False)) + sys.exit() + elif args.tasks == "list_tags": + print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False)) + sys.exit() + elif args.tasks == "list_subtasks": + print(task_manager.list_all_tasks(list_groups=False, list_tags=False)) + sys.exit() + else: + if os.path.isdir(args.tasks): + import glob + + task_names = [] + yaml_path = os.path.join(args.tasks, "*.yaml") + for yaml_file in glob.glob(yaml_path): + config = utils.load_yaml_config(yaml_file) + task_names.append(config) + else: + task_list = args.tasks.split(",") + task_names = task_manager.match_tasks(task_list) + for task in [task for task in task_list if task not in task_names]: + if os.path.isfile(task): + config = utils.load_yaml_config(task) + task_names.append(config) + task_missing = [ + task for task in task_list if task not in task_names and "*" not in task + ] # we don't want errors if a wildcard ("*") task name was used + + if task_missing: + missing = ", ".join(task_missing) + eval_logger.error( + f"Tasks were not found: {missing}\n" + f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks", + ) + raise ValueError( + f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues." + ) + + # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args + if args.trust_remote_code: + eval_logger.info( + "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`" + ) + # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, + # because it's already been determined based on the prior env var before launching our + # script--`datasets` gets imported by dllm_eval internally before these lines can update the env. + import datasets + + datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True + + args.model_args = args.model_args + ",trust_remote_code=True" + ( + eval_logger.info(f"Selected Tasks: {task_names}") + if eval_logger.getEffectiveLevel() >= logging.INFO + else print(f"Selected Tasks: {task_names}") + ) + + request_caching_args = request_caching_arg_to_dict( + cache_requests=args.cache_requests + ) + + results = evaluator.simple_evaluate( + model=args.model, + model_args=args.model_args, + tasks=task_names, + num_fewshot=args.num_fewshot, + batch_size=args.batch_size, + max_batch_size=args.max_batch_size, + device=args.device, + use_cache=args.use_cache, + limit=args.limit, + samples=args.samples, + check_integrity=args.check_integrity, + write_out=args.write_out, + log_samples=args.log_samples, + evaluation_tracker=evaluation_tracker, + system_instruction=args.system_instruction, + apply_chat_template=args.apply_chat_template, + fewshot_as_multiturn=args.fewshot_as_multiturn, + gen_kwargs=args.gen_kwargs, + task_manager=task_manager, + predict_only=args.predict_only, + random_seed=args.seed[0], + numpy_random_seed=args.seed[1], + torch_random_seed=args.seed[2], + fewshot_random_seed=args.seed[3], + confirm_run_unsafe_code=args.confirm_run_unsafe_code, + metadata=metadata, + **request_caching_args, + ) + + if results is not None: + if args.log_samples: + samples = results.pop("samples") + dumped = json.dumps( + results, indent=2, default=handle_non_serializable, ensure_ascii=False + ) + if args.show_config: + print(dumped) + + batch_sizes = ",".join(map(str, results["config"]["batch_sizes"])) + + # Add W&B logging + if args.wandb_args: + try: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + if args.log_samples: + wandb_logger.log_eval_samples(samples) + except Exception as e: + eval_logger.info(f"Logging to Weights and Biases failed due to {e}") + + evaluation_tracker.save_results_aggregated( + results=results, samples=samples if args.log_samples else None + ) + + if args.log_samples: + for task_name, config in results["configs"].items(): + evaluation_tracker.save_results_samples( + task_name=task_name, samples=samples[task_name] + ) + + if ( + evaluation_tracker.push_results_to_hub + or evaluation_tracker.push_samples_to_hub + ): + evaluation_tracker.recreate_metadata_card() + + print( + f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, " + f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}" + ) + + if args.wandb_args: + # Tear down wandb run once all the logging is done. + wandb_logger.run.finish() + + +if __name__ == "__main__": + cli_evaluate() diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/filter.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..bddbf3ab8d1bcbba804f9790ef0290d437bcde69 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/filter.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Iterable, List, Union + +from dllm_eval.api.instance import Instance + + +class Filter(ABC): + """ + Filter classes operate on a per-task level. + They take all model outputs (`instance.resps` for all `task.instances`) + across all instances of a task, and perform operations. + In a single run, one can configure any number of separate filters or lists of filters. + + """ + + def __init__(self, **kwargs) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + @abstractmethod + def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: + """ + Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. + Should return the list of (filtered) response lists *in the same order as they were input*, e.g. + if pass in [, ] should return + [, ] + """ + return resps + + +@dataclass +class FilterEnsemble: + """ + FilterEnsemble creates a pipeline applying multiple filters. + Its intended usage is to stack multiple post-processing steps in order. + `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each + pipeline separately. + """ + + name: str + filters: List[Callable[[], Filter]] + + def apply(self, instances: List[Instance]) -> None: + resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) + resps, docs = list(resps), list(docs) + + for f in self.filters: + # apply filters in sequence + resps = f().apply(resps, docs) + + # add the end results after filtering to filtered_requests of their respective source instances. + # has key `self.name`: each FilterEnsemble applied in a given run should use a different name. + for inst, resp in zip(instances, resps): + inst.filtered_resps[self.name] = resp diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/group.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/group.py new file mode 100644 index 0000000000000000000000000000000000000000..0c60739bbd26c79ecab91f54240798b2ae9e3313 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/group.py @@ -0,0 +1,115 @@ +import abc +from dataclasses import asdict, dataclass +from inspect import getsource +from typing import Any, Callable, List, Optional, Union + + +@dataclass +class AggMetricConfig(dict): + metric: Optional[str] = None + aggregation: Optional[str] = "mean" + weight_by_size: Optional[str] = False + # list of filter names which should be incorporated into the aggregated metric. + filter_list: Optional[Union[str, list]] = "none" + + def __post_init__(self): + if self.aggregation != "mean" and not callable(self.aggregation): + raise ValueError( + f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'." + ) + + if isinstance(self.filter_list, str): + self.filter_list = [self.filter_list] + + +@dataclass +class GroupConfig(dict): + group: Optional[str] = None + group_alias: Optional[str] = None + task: Optional[Union[str, list]] = None + aggregate_metric_list: Optional[ + Union[List[AggMetricConfig], AggMetricConfig, dict] + ] = None + metadata: Optional[dict] = ( + None # by default, not used in the code. allows for users to pass arbitrary info to tasks + ) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + return setattr(self, item, value) + + def __post_init__(self): + if self.aggregate_metric_list is not None: + if isinstance(self.aggregate_metric_list, dict): + self.aggregate_metric_list = [self.aggregate_metric_list] + + self.aggregate_metric_list = [ + AggMetricConfig(**item) if isinstance(item, dict) else item + for item in self.aggregate_metric_list + ] + + def to_dict(self, keep_callable: bool = False) -> dict: + """dumps the current config as a dictionary object, as a printable format. + null fields will not be printed. + Used for dumping results alongside full task configuration + + :return: dict + A printable dictionary version of the TaskConfig object. + + # TODO: should any default value in the TaskConfig not be printed? + """ + cfg_dict = asdict(self) + # remove values that are `None` + for k, v in list(cfg_dict.items()): + if callable(v): + cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) + return cfg_dict + + def serialize_function( + self, value: Union[Callable, str], keep_callable=False + ) -> Union[Callable, str]: + """Serializes a given function or string. + + If 'keep_callable' is True, the original callable is returned. + Otherwise, attempts to return the source code of the callable using 'getsource'. + """ + if keep_callable: + return value + else: + try: + return getsource(value) + except (TypeError, OSError): + return str(value) + + +class ConfigurableGroup(abc.ABC): + def __init__( + self, + config: Optional[dict] = None, + ) -> None: + self._config = GroupConfig(**config) + + @property + def group(self): + return self._config.group + + @property + def group_alias(self): + return self._config.group_alias + + @property + def version(self): + return self._config.version + + @property + def config(self): + return self._config.to_dict() + + @property + def group_name(self) -> Any: + return self._config.group + + def __repr__(self): + return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})" diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/instance.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/instance.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c6afa0644e729ba441728c72a2469fdad07b8f --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/instance.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass, field +from typing import Literal, Optional, Tuple + + +OutputType = Literal[ + "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice" +] + + +@dataclass +class Instance: + request_type: OutputType + doc: dict + arguments: tuple + idx: int + metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field( + default_factory=lambda: (None, None, None) + ) + resps: list = field(default_factory=list) + filtered_resps: dict = field(default_factory=dict) + + # initialized after init + task_name: Optional[str] = None + doc_id: Optional[int] = None + repeats: Optional[int] = None + + def __post_init__(self) -> None: + # unpack metadata field + self.task_name, self.doc_id, self.repeats = self.metadata + + @property + def args(self): + """ + Returns (string,) where `string` is the string to calculate loglikelihood over + """ + return ( + self.arguments if isinstance(self.arguments, tuple) else (self.arguments,) + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/metrics.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2aff6ce92a154a05df3d0bb7d28e09071cd12fbc --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/metrics.py @@ -0,0 +1,578 @@ +import logging +import math +import random +import re +import string +from collections.abc import Iterable +from typing import List + +import numpy as np +import sacrebleu + +from dllm_eval.api.registry import register_aggregation, register_metric + + +eval_logger = logging.getLogger(__name__) + + +# Register Aggregations First +@register_aggregation("bypass") +def bypass_agg(arr): + return 999 + + +@register_aggregation("nanmean") +def nanmean(arr): + if len(arr) == 0 or all(np.isnan(arr)): + return np.nan + return np.nanmean(arr) + + +@register_aggregation("mean") +def mean(arr): + return sum(arr) / len(arr) + + +@register_aggregation("median") +def median(arr): + return arr[len(arr) // 2] + + +# Certain metrics must be calculated across all documents in a benchmark. +# We use them as aggregation metrics, paired with no-op passthrough metric fns. +@register_aggregation("perplexity") +def perplexity(items): + return math.exp(-mean(items)) + + +@register_aggregation("weighted_perplexity") +def weighted_perplexity(items): + return math.exp(-weighted_mean(items)) + + +@register_aggregation("bits_per_byte") +def bits_per_byte(items): + return -weighted_mean(items) / math.log(2) + + +@register_aggregation("f1") +def f1_score(items): + from sklearn.metrics import f1_score + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + fscore = f1_score(golds, preds) + + return np.max(fscore) + + +@register_aggregation("matthews_corrcoef") +def matthews_corrcoef(items): + from sklearn.metrics import matthews_corrcoef + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + return matthews_corrcoef(golds, preds) + + +@register_aggregation("bleu") +def bleu(items): + """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric + for evaluating a generated sentence to a reference sentence. It counts matching + n-grams in the candidate translation to n-grams in the reference text, where + 1-gram or unigram would be each token and a bigram comparison would be each + word pair. The comparison is made regardless of word order + Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/ + Paper: https://www.aclweb.org/anthology/P02-1040/ + + Higher is better + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_bleu(preds, refs).score + + +@register_aggregation("chrf") +def chrf(items): + """chrF++ is a tool for automatic evaluation of machine translation output + based on character n-gram precision and recall enhanced with word n-grams. + Source: https://github.com/m-popovic/chrF + Paper: https://www.aclweb.org/anthology/W15-3049.pdf + + Higher is better # TODO I think + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_chrf(preds, refs).score + + +@register_aggregation("ter") +def ter(items): + """Translation Error Rate is an error metric for machine translation that + measures the number of edits required to change a system output into one + of the references + Source: http://www.cs.umd.edu/~snover/tercom/ + Paper: http://mt-archive.info/AMTA-2006-Snover.pdf + + Lower is better + """ + refs = list(zip(*items))[0] + preds = list(zip(*items))[1] + refs, preds = _sacreformat(refs, preds) + return sacrebleu.corpus_ter(preds, refs).score + + +@register_aggregation("brier_score") +def brier_score(items): # This is a passthrough function + gold, predictions = list(zip(*items)) + bs, num_class = np.array(predictions).shape + + gold = list(gold) + gold_one_hot = np.eye(num_class)[gold] + return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1)) + + +@register_metric( + metric="brier_score", + higher_is_better=False, + output_type=["multiple_choice"], + aggregation="brier_score", +) +def brier_score_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice"], + aggregation="mean", +) +def acc_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_norm", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice"], + aggregation="mean", +) +def acc_norm_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_mutual_info", + higher_is_better=True, + output_type="multiple_choice", + aggregation="mean", +) +def acc_mutual_info_fn(items): # This is a passthrough function + return items + + +### the code used in the `exact_match_hf_evaluate` function is ported from +### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py +### which is under the apache license. + +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +def exact_match_hf_evaluate( + predictions, + references, + regexes_to_ignore=None, + ignore_case=False, + ignore_punctuation=False, + ignore_numbers=False, +): + if regexes_to_ignore is not None: + for s in regexes_to_ignore: + predictions = np.array([re.sub(s, "", x) for x in predictions]) + references = np.array([re.sub(s, "", x) for x in references]) + else: + predictions = np.asarray(predictions) + references = np.asarray(references) + + if ignore_case: + predictions = np.char.lower(predictions) + references = np.char.lower(references) + + if ignore_punctuation: + repl_table = string.punctuation.maketrans("", "", string.punctuation) + predictions = np.char.translate(predictions, table=repl_table) + references = np.char.translate(references, table=repl_table) + + if ignore_numbers: + repl_table = string.digits.maketrans("", "", string.digits) + predictions = np.char.translate(predictions, table=repl_table) + references = np.char.translate(references, table=repl_table) + + score_list = predictions == references + + return {"exact_match": np.mean(score_list)} + + +### + + +@register_metric( + metric="exact_match", + higher_is_better=True, + output_type="generate_until", + aggregation="mean", +) +def exact_match_fn(**kwargs): + return exact_match_hf_evaluate(**kwargs) + + +@register_metric( + metric="perplexity", + higher_is_better=False, + output_type="loglikelihood", + aggregation="perplexity", +) +def perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="word_perplexity", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="weighted_perplexity", +) +def word_perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="byte_perplexity", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="weighted_perplexity", +) +def byte_perplexity_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="bits_per_byte", + higher_is_better=False, + output_type="loglikelihood_rolling", + aggregation="bits_per_byte", +) +def bits_per_byte_fn(items): # This is a passthrough function + return items + + +def pop_stddev(arr): + mu = mean(arr) + return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr)) + + +def sample_stddev(arr): + mu = mean(arr) + return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1)) + + +def mean_stderr(arr): + return sample_stddev(arr) / math.sqrt(len(arr)) + + +@register_metric( + metric="bypass", + higher_is_better=True, + output_type=["loglikelihood", "multiple_choice", "generate_until"], + aggregation="bypass", +) +def bypass(items): + return None + + +@register_metric( + metric="mcc", + higher_is_better=True, + output_type="multiple_choice", + aggregation="matthews_corrcoef", +) +def mcc_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="f1", + higher_is_better=True, + output_type="multiple_choice", + aggregation="f1", +) +def f1_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="bleu", + higher_is_better=True, + output_type="generate_until", + aggregation="bleu", +) +def bleu_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="chrf", + higher_is_better=True, + output_type="generate_until", + aggregation="chrf", +) +def chrf_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="ter", + higher_is_better=True, + output_type="generate_until", + aggregation="ter", +) +def ter_fn(items): # This is a passthrough function + return items + + +@register_metric( + metric="acc_all", + higher_is_better=True, + output_type="loglikelihood", + aggregation="mean", +) +def acc_all(items): + # Only count as correct if all answers are labeled correctly for each question + question_scoring_dict = {} + preds = list(zip(*items))[0] + docs = list(zip(*items))[1] + + for doc, pred in zip(docs, preds): + paragraph_id = doc["idx"]["paragraph"] + question_id = doc["idx"]["question"] + if (paragraph_id, question_id) not in question_scoring_dict: + question_scoring_dict[(paragraph_id, question_id)] = [] + + gold_label = doc["label"] == 1 + + question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred) + acc = np.mean([int(all(x)) for x in question_scoring_dict.values()]) + return acc + + +def acc_all_stderr(items): + # Only count as correct if all answers are labeled correctly for each question + question_scoring_dict = {} + preds = list(zip(*items))[0] + docs = list(zip(*items))[1] + + for doc, pred in zip(docs, preds): + question_id = doc["idx"]["question"] + if question_id not in question_scoring_dict: + question_scoring_dict[question_id] = [] + + gold_label = doc["label"] == 1 + question_scoring_dict[question_id].append(gold_label == pred) + + acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()]) + return acc + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + """Compute max metric between prediction and each ground truth.""" + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def weighted_mean(items): + a, b = zip(*items) + return sum(a) / sum(b) + + +def is_non_str_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, str) + + +def _sacreformat(refs, preds): + """Format refs and preds for sacrebleu corpus calculation. It is very particular""" + # Sacrebleu expects (List[str], List[List[str]) + # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...]) + + # Note [ref1_stream] is the first reference for each pred. + # So lists are size N and (M, N) for N preds and M possible refs for each pred + # This is a different order of dimensions that I would expect + + # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds + # Must become List[List[str]] with the inner list corresponding to preds + if not is_non_str_iterable(refs): + refs = list(refs) + if not is_non_str_iterable(refs[0]): + refs = [[ref] for ref in refs] + refs = list(zip(*refs)) + # Note the number of refs in each ref list much match the number of preds + + # We expect preds to be List[str] or List[List[str]]. Must become List[str] + if not is_non_str_iterable(preds): + preds = list(preds) + if is_non_str_iterable(preds[0]): + assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}" + preds = [pred[0] for pred in preds] + + return refs, preds + + +# stderr stuff + + +class _bootstrap_internal: + def __init__(self, f, n) -> None: + self.f = f + self.n = n + + def __call__(self, v): + i, xs = v + rnd = random.Random() + rnd.seed(i) + res = [] + for _ in range(self.n): + res.append(self.f(rnd.choices(xs, k=len(xs)))) + return res + + +def bootstrap_stderr(f, xs, iters): + import multiprocessing as mp + + pool = mp.Pool(mp.cpu_count()) + # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something + # equivalent to stderr calculated without Bessel's correction in the stddev. + # Unfortunately, I haven't been able to figure out what the right correction is + # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but + # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator) + # Thankfully, shouldn't matter because our samples are pretty big usually anyways + res = [] + chunk_size = min(1000, iters) + from tqdm import tqdm + + print("bootstrapping for stddev:", f.__name__) + for bootstrap in tqdm( + pool.imap( + _bootstrap_internal(f, chunk_size), + [(i, xs) for i in range(iters // chunk_size)], + ), + total=iters // chunk_size, + ): + # sample w replacement + res.extend(bootstrap) + + pool.close() + return sample_stddev(res) + + +def stderr_for_metric(metric, bootstrap_iters: int): + if bootstrap_iters <= 0: + # return no function (don't compute stderr) if bootstrap iters = 0 + return None + + bootstrappable = [ + median, + matthews_corrcoef, + f1_score, + perplexity, + bleu, + chrf, + ter, + nanmean, + ] + + if metric in bootstrappable: + return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters) + + stderr = {mean: mean_stderr, acc_all: acc_all_stderr} + + return stderr.get(metric, None) + + +def pooled_sample_stderr(stderrs: List[float], sizes: List[int]): + # Used to aggregate bootstrapped stderrs across subtasks in a group, + # when we are weighting by the size of each subtask. + # + + assert len(stderrs) == len(sizes) + + # formula source: https://en.wikipedia.org/wiki/Pooled_variance + # and: https://stats.stackexchange.com/a/4841331 + # this empirically seems to match running `stderr_for_metric` on all instances + # from the subtasks concatenated with each other. + pooled_sample_var = ( + sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)]) + ) / (sum(sizes) - len(sizes)) + + return np.sqrt(pooled_sample_var / sum(sizes)) + + +def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None): + assert metrics is not None, ( + "Need to pass a list of each subtask's metric for this stderr aggregation" + ) + assert len(stderrs) == len(sizes) and len(sizes) == len(metrics) + + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation. + # This formula depends on sample means. + # removed because it seems to give erroneously huge stderrs for groupings of tasks + # and does not seem to match up with bootstrap-calculated stderrs for groups. + + ### don't use this unless a statistician has told you it's the right thing to do ### + + # accumulators: we'll aggregate pairwise N - 1 times + variance = stderrs[0] ** 2 + curr_size = sizes[0] + curr_score = metrics[0] + + for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]): + curr_score = ((curr_score * curr_size) + (score * size)) / ( + curr_size + size + ) # NOTE: this assumes our aggregation fn is "mean" + + variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / ( + curr_size + size - 1 + ) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * ( + curr_score - score + ) ** 2 + + return np.sqrt(variance) + + +def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True): + # A helper function that is used to aggregate + # subtask scores cross-task. + # TODO: does not hold for non-mean aggregations + if not weight_by_size: + sizes = [1] * len(sizes) + + assert len(metrics) == len(sizes) + + return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/model.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9364a9312d78c1029e5edf38d61f192afca91334 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/model.py @@ -0,0 +1,493 @@ +import abc +import hashlib +import json +import logging +import os +from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union + +import transformers +from sqlitedict import SqliteDict +from tqdm import tqdm + +from dllm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +T = TypeVar("T", bound="LM") + + +class LM(abc.ABC): + def __init__(self) -> None: + """Defines the interface that should be implemented by all LM subclasses. + LMs are assumed to take text (strings) as input and yield strings as output + (inputs/outputs should be tokenization-agnostic.) + + """ + # set rank and world size to a single process, by default. + self._rank = 0 + self._world_size = 1 + self.cache_hook = CacheHook(None) + + @abc.abstractmethod + def loglikelihood(self, requests) -> List[Tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. + + :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. + """ + pass + + @abc.abstractmethod + def loglikelihood_rolling(self, requests) -> List[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. + - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. + Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: BOS/EOS + Max context length: 4 + Resulting input/prediction pairs: + + INPUT: BOS 0 1 2 + PRED: 0 1 2 3 + + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + + INPUT: 5 6 7 8 + PRED: 8 9 + + Observe that: + 1. Each token is predicted exactly once + 2. For the last pair, we provide the full context, but only score the last two tokens + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the BOS/EOS token. + Can also be overridden for custom cases by `prefix_token_id`. + """ + pass + + # TODO: Add an optional max length + @abc.abstractmethod + def generate_until(self, requests) -> List[str]: + """Generate greedily until a stopping sequence + + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs). + context: str + Context string + gen_kwargs: dict + A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc. + :return: list[str] + A list of model generated continuations. + continuation: str + The generated continuation. + """ + pass + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt=True + ) -> str: + """ + Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM. + + :param chat_history: list[dict[str, str]] + A list of dictionaries with keys 'role' and 'content'. + Values are strings representing the role name and the content of the message, respectively. + :param add_generation_prompt: bool + Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message. + :return: str + A string representing the chat history in a format that can be used as input to the LM. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'apply_chat_template' method for your model type." + ) + + @classmethod + def create_from_arg_string( + cls: Type[T], arg_string: str, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given argument string and additional config. + + Parameters: + - arg_string: A string containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. + """ + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + @classmethod + def create_from_arg_obj( + cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None + ) -> T: + """ + Creates an instance of the LM class using the given arg_obj + + Parameters: + - arg_obj: A dict containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the LM class. + """ + + additional_config = {} if additional_config is None else additional_config + additional_config = { + k: v for k, v in additional_config.items() if v is not None + } + + return cls(**arg_dict, **additional_config) + + @property + def rank(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._rank + + @property + def world_size(self): + # used in the case of parallelism. Hardcoded to + # ensure no errors arise using API models which do + # not support multi-device parallelism nor expect it. + return self._world_size + + @property + def tokenizer_name(self) -> str: + """Must be defined for LM subclasses which implement Chat Templating. + Should return the name of the tokenizer or chat template used. + Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used. + """ + raise NotImplementedError( + "To use this model with chat templates, please implement the 'tokenizer_name' property." + ) + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """Returns the chat template structure for user/assistant messages if a template is provided. + This method is intended to be overridden in a subclass to define a specific chat template format. + For models that do not support chat templates, this method returns None by default. + """ + + return "" + + def set_cache_hook(self, cache_hook) -> None: + self.cache_hook = cache_hook + + +### SQLite-based caching of LM responses +def hash_args(attr, args): + dat = json.dumps([attr] + list(args)) + return hashlib.sha256(dat.encode("utf-8")).hexdigest() + + +class CacheHook: + def __init__(self, cachinglm) -> None: + if cachinglm is None: + self.dbdict = None + return + + self.dbdict = cachinglm.dbdict + + def add_partial(self, attr, req, res) -> None: + if self.dbdict is None: + return + hsh = hash_args(attr, req) + self.dbdict[hsh] = res + + +class CachingLM: + def __init__(self, lm, cache_db) -> None: + """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. + + :param lm: LM + Underlying LM + :param cache_db: str + Path to cache db + """ + self.lm = lm + self.cache_db = cache_db + if os.path.dirname(cache_db): + os.makedirs(os.path.dirname(cache_db), exist_ok=True) + self.dbdict = SqliteDict(cache_db, autocommit=True) + + # add hook to lm + lm.set_cache_hook(self.get_cache_hook()) + + def __getattr__(self, attr: str): + lm_attr = getattr(self.lm, attr) + if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]: + eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM") + return lm_attr + + def fn(requests): + res = [] + remaining_reqs = [] + warned = False + # figure out which ones are cached and which ones are new + eval_logger.info( + f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..." + ) + for req in tqdm(requests, desc="Checking cached requests"): + hsh = hash_args(attr, req.args) + if attr == "generate_until" and req.args[1].get("do_sample", False): + # when we are doing non-greedy generation, don't use the cache + # (else every "randomly sampled" generation would be identical for repeats > 1). + if not warned: + eval_logger.warning( + f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests." + ) + warned = True + res.append(None) + remaining_reqs.append(req) + elif hsh in self.dbdict: + ob = self.dbdict[hsh] + + assert ob is not None + + res.append(ob) + else: + res.append(None) + remaining_reqs.append(req) + eval_logger.info( + f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}" + ) + if remaining_reqs: + # actually run the LM on the requests that do not have cached results + rem_res = getattr(self.lm, attr)(remaining_reqs) + else: + rem_res = [] + + # stick the new ones back into the list and also cache any of the new ones + resptr = 0 + for req, r in zip(remaining_reqs, rem_res): + while res[resptr] is not None: + resptr += 1 + + res[resptr] = r + + # caching + hsh = hash_args(attr, req.args) + self.dbdict[hsh] = r + self.dbdict.commit() + + return res + + return fn + + def get_cache_hook(self): + return CacheHook(self) + + +class TemplateLM(LM): + """ + A class acting as intermediary between the LM base class + and boilerplate often included in other LM subclasses. + """ + + tokenizer = None + + @property + @abc.abstractmethod + def eot_token_id(self): + pass + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + return self.eot_token_id + + @abc.abstractmethod + def tok_encode(self, string: str, **kwargs) -> List[int]: + """ + Tokenize a string using the model's tokenizer and return a list of token IDs. + """ + pass + + @abc.abstractmethod + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + pass + + def _encode_pair( + self, context: str, continuation: str + ) -> Tuple[List[int], List[int]]: + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + model_class = getattr(self, "AUTO_MODEL_CLASS", None) + + if model_class == transformers.AutoModelForSeq2SeqLM: + context_enc = self.tok_encode(context) + continuation_enc = self.tok_encode(continuation, add_special_tokens=False) + else: + whole_enc = self.tok_encode(context + continuation) + context_enc = self.tok_encode(context) + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood( + self, requests, disable_tqdm: bool = False + ) -> List[Tuple[float, bool]]: + new_reqs = [] + for context, continuation in [req.args for req in requests]: + if context == "": + # BOS or EOS as context + context_enc, continuation_enc = ( + [self.prefix_token_id], + self.tok_encode(continuation), + ) + else: + context_enc, continuation_enc = self._encode_pair(context, continuation) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm) + + @abc.abstractmethod + def loglikelihood_rolling( + self, requests, disable_tqdm: bool = False + ) -> List[float]: + pass + + @abc.abstractmethod + def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]: + pass + + def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: + """ + Set and get the appropriate chat template for the model. + This method sets the tokenizer's chat_template and returns the template string for reproducibility. + + The template selection logic is adapted from the Transformers library's `apply_chat_template` + method in the Tokenizer class. The original implementation can be found at: + https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687 + + This method ensures that the right template is chosen based on the following: + 0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string. + 1. If the model's tokenizer has multiple templates: + a. Use the specified template if it exists in the dictionary. + b. Use the default template from the list if no specific template is provided. + c. Raise an error if no default template exists and no specific template is provided. + 2. If the model's tokenizer has a single template or no template: + a. Use the tokenizer's chat template if available. + b. Fall back to the default chat template if no tokenizer chat template exists. + + Args: + chat_template (Union[bool, str]): Specifies the chat template to use. + - If False or None, no template is applied. + - If True, the default or only available template is used. + - If a string, the template with the matching name is used. + + Returns: + Optional[str]: The selected chat template, or None if no template is applied. + """ + if self.tokenizer is None: + return "" + + if chat_template is False or chat_template is None: + eval_logger.warning( + "model.chat_template was called with the chat_template set to False or None. " + "Therefore no chat template will be applied. Make sure this is an intended behavior." + ) + return None + + # Convert boolean chat_template to None to ensure compatibility with the adapted logic + if isinstance(chat_template, bool): + chat_template = None + using_default_template = False + + # First, handle the cases when the model has a dict of multiple templates + try: + template = ( + self.tokenizer.chat_template or self.tokenizer.default_chat_template + ) + except AttributeError: + return None + + if isinstance(template, dict): + using_default_dict = self.tokenizer.chat_template is None + + if chat_template is not None: + if chat_template in template: + selected_template = template[chat_template] + if using_default_dict: + using_default_template = True + else: + raise ValueError( + f"The specified chat template '{chat_template}' is not available. " + f"Available template names are {sorted(template.keys())}." + ) + else: + # If user didn't pass a chat template, use the default template from the dict + if "default" in template: + selected_template = template["default"] + using_default_template = True + else: + raise ValueError( + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template.keys())}." + ) + + # Cases when the model has a single template or no template + else: + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template + if isinstance(chat_template, str): + eval_logger.warning( + "Chat template name provided, but the tokenizer's chat template is not a dictionary. " + "Using the tokenizer's chat template or the default template instead." + ) + if self.tokenizer.chat_template is not None: + selected_template = self.tokenizer.chat_template + else: + selected_template = self.tokenizer.default_chat_template + using_default_template = True + + if using_default_template: + eval_logger.warning( + "No chat template is set for this tokenizer, falling back to a default class-level template. This is " + "very error-prone, because models are often trained with templates different from the class default! " + "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " + "point any code depending on them will stop working. We recommend setting a valid chat template before " + "then to ensure that this model continues working without issues." + ) + + return selected_template diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/registry.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2b2e415a0a19862a41bde307bbad2e6ba326f5 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/registry.py @@ -0,0 +1,196 @@ +import logging +from typing import Callable, Dict, Union + +import evaluate as hf_evaluate + +from dllm_eval.api.model import LM + + +eval_logger = logging.getLogger(__name__) + +MODEL_REGISTRY = {} + + +def register_model(*names): + # either pass a list or a single alias. + # function receives them as a tuple of strings + + def decorate(cls): + for name in names: + assert issubclass(cls, LM), ( + f"Model '{name}' ({cls.__name__}) must extend LM class" + ) + + assert name not in MODEL_REGISTRY, ( + f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead." + ) + + MODEL_REGISTRY[name] = cls + return cls + + return decorate + + +def get_model(model_name): + try: + return MODEL_REGISTRY[model_name] + except KeyError: + raise ValueError( + f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" + ) + + +TASK_REGISTRY = {} +GROUP_REGISTRY = {} +ALL_TASKS = set() +func2task_index = {} + + +def register_task(name): + def decorate(fn): + assert name not in TASK_REGISTRY, ( + f"task named '{name}' conflicts with existing registered task!" + ) + + TASK_REGISTRY[name] = fn + ALL_TASKS.add(name) + func2task_index[fn.__name__] = name + return fn + + return decorate + + +def register_group(name): + def decorate(fn): + func_name = func2task_index[fn.__name__] + if name in GROUP_REGISTRY: + GROUP_REGISTRY[name].append(func_name) + else: + GROUP_REGISTRY[name] = [func_name] + ALL_TASKS.add(name) + return fn + + return decorate + + +OUTPUT_TYPE_REGISTRY = {} +METRIC_REGISTRY = {} +METRIC_AGGREGATION_REGISTRY = {} +AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} +HIGHER_IS_BETTER_REGISTRY = {} +FILTER_REGISTRY = {} + +DEFAULT_METRIC_REGISTRY = { + "loglikelihood": [ + "perplexity", + "acc", + ], + "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"], + "multiple_choice": ["acc", "acc_norm"], + "generate_until": ["exact_match"], +} + + +def register_metric(**args): + # TODO: do we want to enforce a certain interface to registered metrics? + def decorate(fn): + assert "metric" in args + name = args["metric"] + + for key, registry in [ + ("metric", METRIC_REGISTRY), + ("higher_is_better", HIGHER_IS_BETTER_REGISTRY), + ("aggregation", METRIC_AGGREGATION_REGISTRY), + ]: + if key in args: + value = args[key] + assert value not in registry, ( + f"{key} named '{value}' conflicts with existing registered {key}!" + ) + + if key == "metric": + registry[name] = fn + elif key == "aggregation": + registry[name] = AGGREGATION_REGISTRY[value] + else: + registry[name] = value + + return fn + + return decorate + + +def get_metric(name: str, hf_evaluate_metric=False) -> Callable: + if not hf_evaluate_metric: + if name in METRIC_REGISTRY: + return METRIC_REGISTRY[name] + else: + eval_logger.warning( + f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..." + ) + + try: + metric_object = hf_evaluate.load(name) + return metric_object.compute + except Exception: + eval_logger.error( + f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric", + ) + + +def register_aggregation(name: str): + def decorate(fn): + assert name not in AGGREGATION_REGISTRY, ( + f"aggregation named '{name}' conflicts with existing registered aggregation!" + ) + + AGGREGATION_REGISTRY[name] = fn + return fn + + return decorate + + +def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: + try: + return AGGREGATION_REGISTRY[name] + except KeyError: + eval_logger.warning(f"{name} not a registered aggregation metric!") + + +def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]: + try: + return METRIC_AGGREGATION_REGISTRY[name] + except KeyError: + eval_logger.warning(f"{name} metric is not assigned a default aggregation!") + + +def is_higher_better(metric_name) -> bool: + try: + return HIGHER_IS_BETTER_REGISTRY[metric_name] + except KeyError: + eval_logger.warning( + f"higher_is_better not specified for metric '{metric_name}'!" + ) + + +def register_filter(name): + def decorate(cls): + if name in FILTER_REGISTRY: + eval_logger.info( + f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}" + ) + FILTER_REGISTRY[name] = cls + return cls + + return decorate + + +def get_filter(filter_name: Union[str, Callable]) -> Callable: + try: + return FILTER_REGISTRY[filter_name] + except KeyError as e: + if callable(filter_name): + return filter_name + else: + eval_logger.warning(f"filter `{filter_name}` is not registered!") + raise e diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/samplers.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..969789ef2111dcb8ee3b7eed4c69d54572d6c302 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/samplers.py @@ -0,0 +1,232 @@ +import logging +import warnings +from functools import partial +from typing import TYPE_CHECKING, Iterable, Optional, Union + +import datasets + + +if TYPE_CHECKING: + from random import Random + + from dllm_eval.api.task import ConfigurableTask, Task + +eval_logger = logging.getLogger("lm-eval") + + +class ContextSampler: + def __init__( + self, + docs: list[dict], + task: Union["Task", "ConfigurableTask"], + fewshot_indices: Optional[Iterable] = None, + rnd: Optional["Random"] = None, + ) -> None: + self.rnd = rnd + if not self.rnd: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!" + ) + + self.task = task + self.config = task._config + + self.target_delimiter = self.config.target_delimiter + self.fewshot_delimiter = self.config.fewshot_delimiter + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_text", None) is not None + ): + self.doc_to_text = partial( + self.task.doc_to_text, + doc_to_text=self.config.fewshot_config.get("doc_to_text", None), + ) + else: + self.doc_to_text = self.task.doc_to_text + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_target", None) is not None + ): + self.doc_to_target = partial( + self.task.doc_to_target, + doc_to_target=self.config.fewshot_config.get("doc_to_target", None), + ) + else: + self.doc_to_target = self.task.doc_to_target + + if ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("doc_to_choice", None) is not None + ): + self.doc_to_choice = partial( + self.task.doc_to_choice, + doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None), + ) + else: + self.doc_to_choice = self.task.doc_to_choice + + self.docs = docs # HF dataset split, provided by task._fewshot_docs() + if fewshot_indices: # subset few-shot docs from + if not isinstance(self.docs, datasets.Dataset): + raise ValueError( + "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously" + ) + self.docs = self.docs.select(fewshot_indices) + + def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None): + # draw an extra fewshot sample if using same split as evaluating on + prefix = gen_prefix + " " if gen_prefix else "" + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = "" + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + if self.config.doc_to_choice is None or isinstance(doc_content, str): + labeled_examples += doc_content + else: + labeled_examples += self.doc_to_choice(doc)[doc_content] + + if doc_target != "": + if self.target_delimiter.isspace() and str(doc_target)[0].isspace(): + # TODO: add logger warn once here. + warnings.warn( + "Both target_delimiter and target start with a space. This may cause issues.", + Warning, + stacklevel=2, + ) + labeled_examples += self.target_delimiter + labeled_examples += prefix + labeled_examples += ( + str(doc_target[0]) + if isinstance(doc_target, list) + else doc_target + if self.config.doc_to_choice is None or isinstance(doc_target, str) + else str(self.doc_to_choice(doc)[doc_target]) + ) + labeled_examples += self.fewshot_delimiter + + return labeled_examples + + def get_chat_context( + self, + doc: dict, + num_fewshot: int, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ): + # TODO: Do we need any other delimiter + prefix = gen_prefix + " " if gen_prefix else "" + chat_history = [] + # draw an extra fewshot sample if using same split as evaluating on + n_samples = ( + num_fewshot + 1 + if self.config.fewshot_split == self.config.test_split + else num_fewshot + ) + # draw `n_samples` docs from fewshot_docs + fewshotex = self.sample(n_samples) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + # TODO: should we just stop people from using fewshot from same split as evaluating? + selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] + + if fewshot_as_multiturn: + for doc in selected_docs: + doc_content = self.doc_to_text(doc) + doc_target = self.doc_to_target(doc) + chat_history.append( + { + "role": "user", + "content": doc_content + if self.config.doc_to_choice is None + or isinstance(doc_content, str) + else self.doc_to_choice(doc)[doc_content], + } + ) + chat_history.append( + { + "role": "assistant", + "content": prefix + str(doc_target[0]) + if isinstance(doc_target, list) + else prefix + doc_target + if self.config.doc_to_choice is None + or isinstance(doc_target, str) + else prefix + str(self.doc_to_choice(doc)[doc_target]), + } + ) + else: + # get fewshot context as one user turn + chat_history.append( + { + "role": "user", + "content": self.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ), + } + ) + + return chat_history + + def sample(self, n: int): + """ + Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. + """ + + return self.rnd.sample(self.docs, n) + + +class FirstNSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + Draw the first `n` samples in order from the specified split. + Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. + """ + assert n <= len(self.docs), ( + f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available." + ) + return self.docs[:n] + + +class BalancedSampler(ContextSampler): + def sample(self, n: int) -> None: + """ + TODO: this should return approximately class-balanced samples from our fewshot examples. + TODO: what order should they be in? maybe random? + """ + + pass + + +class ManualSampler(ContextSampler): + def sample(self, n: int) -> None: + """ """ + pass + + +SAMPLER_REGISTRY = { + "default": ContextSampler, + "first_n": FirstNSampler, +} + + +def get_sampler(name: str): + try: + return SAMPLER_REGISTRY[name] + except KeyError: + raise ValueError( + f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}" + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/task.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/task.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6321af0b2b8777e0322745a9875656ec194190 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/api/task.py @@ -0,0 +1,1881 @@ +import abc +import ast +import logging +import random +import re +from collections.abc import Callable +from copy import deepcopy +from dataclasses import asdict, dataclass +from inspect import getsource +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + Union, +) + +import datasets +import numpy as np +from tqdm import tqdm + +from dllm_eval import utils +from dllm_eval.api import samplers +from dllm_eval.api.instance import Instance, OutputType +from dllm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity +from dllm_eval.api.registry import ( + AGGREGATION_REGISTRY, + DEFAULT_METRIC_REGISTRY, + get_aggregation, + get_metric, + get_metric_aggregation, + is_higher_better, +) +from dllm_eval.caching.cache import load_from_cache, save_to_cache +from dllm_eval.filters import build_filter_ensemble +from dllm_eval.prompts import get_prompt + + +ALL_OUTPUT_TYPES = [ + "loglikelihood", + "multiple_choice", + "loglikelihood_rolling", + "generate_until", +] + +eval_logger = logging.getLogger(__name__) + + +@dataclass +class TaskConfig(dict): + # task naming/registry + task: Optional[str] = None + task_alias: Optional[str] = None + tag: Optional[Union[str, list]] = None + # HF dataset options. + # which dataset to use, + # and what splits for what purpose + custom_dataset: Optional[Callable] = None + dataset_path: Optional[str] = None + dataset_name: Optional[str] = None + dataset_kwargs: Optional[dict] = None + training_split: Optional[str] = None + validation_split: Optional[str] = None + test_split: Optional[str] = None + fewshot_split: Optional[str] = ( + None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) + ) + # formatting / prompting options. + # see docs/advanced_task_guide.md for more info + process_docs: Optional[Callable] = None + doc_to_text: Optional[Union[Callable, str]] = None + doc_to_target: Optional[Union[Callable, str]] = None + doc_to_image: Union[Callable, str] = None + doc_to_audio: Union[Callable, str] = None + unsafe_code: bool = False + doc_to_choice: Optional[Union[Callable, str, dict, list]] = None + process_results: Optional[Union[Callable, str]] = None + use_prompt: Optional[str] = None + description: str = "" + target_delimiter: str = " " + fewshot_delimiter: str = "\n\n" + fewshot_config: Optional[dict] = None + # runtime configuration options + num_fewshot: Optional[int] = None + # scoring options + metric_list: Optional[list] = None + output_type: OutputType = "generate_until" + generation_kwargs: Optional[dict] = None + repeats: int = 1 + filter_list: Optional[Union[str, list]] = None + should_decontaminate: bool = False + doc_to_decontamination_query: Optional[str] = None + gen_prefix: Optional[str] = None + metadata: Optional[dict] = ( + None # by default, not used in the code. allows for users to pass arbitrary info to tasks + ) + + def __post_init__(self) -> None: + if self.generation_kwargs is not None: + if self.output_type != "generate_until": + eval_logger.warning( + f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!" + ) + + if "temperature" in self.generation_kwargs: + self.generation_kwargs["temperature"] = float( + self.generation_kwargs["temperature"] + ) + + if "until" not in self.generation_kwargs: + eval_logger.warning( + f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}" + ) + self.generation_kwargs["until"] = [self.fewshot_delimiter] + else: + if self.output_type == "generate_until": + # ensure that we greedily generate in absence of explicit arguments otherwise + self.generation_kwargs = { + "until": ( + None + if self.fewshot_delimiter is None + else [self.fewshot_delimiter] + ), + "do_sample": False, + "temperature": 0, + } + eval_logger.warning( + f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}" + ) + + def __getitem__(self, item): + return getattr(self, item) + + def __setitem__(self, item, value): + return setattr(self, item, value) + + def to_dict(self, keep_callable: bool = False) -> dict: + """dumps the current config as a dictionary object, as a printable format. + null fields will not be printed. + Used for dumping results alongside full task configuration + + :return: dict + A printable dictionary version of the TaskConfig object. + + # TODO: should any default value in the TaskConfig not be printed? + """ + cfg_dict = asdict(self) + # remove values that are `None` + for k, v in list(cfg_dict.items()): + if v is None: + cfg_dict.pop(k) + elif k == "metric_list": + for metric_dict in v: + for metric_key, metric_value in metric_dict.items(): + if callable(metric_value): + metric_dict[metric_key] = self.serialize_function( + metric_value, keep_callable=keep_callable + ) + cfg_dict[k] = v + elif callable(v): + cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) + return cfg_dict + + def serialize_function( + self, value: Union[Callable, str], keep_callable=False + ) -> Union[Callable, str]: + """Serializes a given function or string. + + If 'keep_callable' is True, the original callable is returned. + Otherwise, attempts to return the source code of the callable using 'getsource'. + """ + if keep_callable: + return value + else: + try: + return getsource(value) + except (TypeError, OSError): + return str(value) + + +class Task(abc.ABC): + """A task represents an entire benchmark including its dataset, problems, + answers, and evaluation methods. See BoolQ for a simple example implementation + + A `doc` can be any python object which represents one instance of evaluation. + This is usually a dictionary e.g. + {"question": ..., "answer": ...} or + {"question": ..., question, answer) + """ + + VERSION: Optional[Union[int, str]] = None + + # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub + # or a path to a custom `datasets` loading script. + DATASET_PATH: Optional[str] = None + + # The name of a subset within `DATASET_PATH`. + DATASET_NAME: Optional[str] = None + + OUTPUT_TYPE: Optional[OutputType] = None + + def __init__( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode: Optional[datasets.DownloadMode] = None, + config: Optional[Mapping] = None, # Union[dict, TaskConfig] + ) -> None: + """ + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.download(data_dir, cache_dir, download_mode) + self._training_docs: Optional[list] = None + self._fewshot_docs: Optional[list] = None + self._instances: Optional[List[Instance]] = None + + self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig() + + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + self.fewshot_rnd: Optional[random.Random] = ( + None # purposely induce errors in case of improper usage + ) + + def download( + self, + data_dir: Optional[str] = None, + cache_dir: Optional[str] = None, + download_mode=None, + ) -> None: + """Downloads and returns the task dataset. + Override this method to download the dataset from a custom API. + + :param data_dir: str + Stores the path to a local folder containing the `Task`'s data files. + Use this to specify the path to manually downloaded data (usually when + the dataset is not publicly accessible). + :param cache_dir: str + The directory to read/write the `Task` dataset. This follows the + HuggingFace `datasets` API with the default cache directory located at: + `~/.cache/huggingface/datasets` + NOTE: You can change the cache location globally for a given process + by setting the shell environment variable, `HF_DATASETS_CACHE`, + to another directory: + `export HF_DATASETS_CACHE="/path/to/another/directory"` + :param download_mode: datasets.DownloadMode + How to treat pre-existing `Task` downloads and data. + - `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS` + Reuse download and reuse dataset. + - `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS` + Reuse download with fresh dataset. + - `datasets.DownloadMode.FORCE_REDOWNLOAD` + Fresh download and fresh dataset. + """ + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + data_dir=data_dir, + cache_dir=cache_dir, + download_mode=download_mode, + ) + + @property + def config(self) -> TaskConfig: + """Returns the TaskConfig associated with this class.""" + return self._config + + @abc.abstractmethod + def has_training_docs(self): + """Whether the task has a training set""" + pass + + @abc.abstractmethod + def has_validation_docs(self): + """Whether the task has a validation set""" + pass + + @abc.abstractmethod + def has_test_docs(self): + """Whether the task has a test set""" + pass + + def training_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def validation_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def test_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + return [] + + def fewshot_docs(self) -> Iterable: + """ + :return: Iterable[obj] + A iterable of any object, that doc_to_text can handle + """ + if self.has_training_docs(): + return self.training_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + if self.config.get("num_fewshot", 0) > 0: + eval_logger.warning( + f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False" + ", using test_docs as fewshot_docs but this is not recommended." + ) + return self.test_docs() + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. + """ + return doc + + @property + def instances(self) -> List[Instance]: + """After calling `task.build_all_requests()`, tasks + maintain a list of the dataset instances which will be evaluated. + """ + return self._instances + + def fewshot_examples(self, k, rnd): + if self._training_docs is None: + self._training_docs = list(self.training_docs()) + + return rnd.sample(self._training_docs, k) + + def doc_to_decontamination_query(self, doc): + raise NotImplementedError( + "Override doc_to_decontamination_query with document specific decontamination query." + ) + + @abc.abstractmethod + def doc_to_text(self, doc): + pass + + @abc.abstractmethod + def doc_to_target(self, doc): + pass + + # not an abstractmethod because not every language-only task has to implement this + def doc_to_image(self, doc): + raise NotImplementedError + + def doc_to_audio(self, doc): + raise NotImplementedError + + def doc_to_prefix(self, doc): + return "" + + def build_all_requests( + self, + *, + limit: Union[int, None] = None, + samples: Optional[List[int]] = None, + rank: int = 0, + world_size: int = 1, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + tokenizer_name: str = "", + ) -> None: + """Build a set of Instances for a task, and store them in task.instances""" + + # used with caching + og_limit = limit + + cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}" + cache_key += "-chat_template" if apply_chat_template else "" + cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else "" + cache_key += ( + f"-system_prompt_hash{utils.hash_string(system_instruction)}" + if system_instruction is not None + else "" + ) + cache_key += f"-tokenizer{tokenizer_name}" + + cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests) + + if cache_requests and cached_instances and not rewrite_requests_cache: + cached_instances = cached_instances[:limit] + + flattened_instances = [ + instance + for instance_group in cached_instances + for instance in instance_group + ] + + self._instances = flattened_instances + return + + eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...") + + instances = [] + + # process all documents when caching is specified for simplicity + if ( + cache_requests + and (not cached_instances or rewrite_requests_cache) + and limit is not None + ): + limit = None + + doc_id_docs = list( + self.doc_iterator( + rank=rank, limit=limit, samples=samples, world_size=world_size + ) + ) + + num_docs = len(doc_id_docs) + + for doc_id, doc in tqdm( + doc_id_docs, + total=num_docs, + ): + # sample fewshot context #TODO: need to offset doc_id by rank now! + fewshot_ctx = self.fewshot_context( + doc, + num_fewshot=0 + if self.config.num_fewshot is None + else self.config.num_fewshot, + system_instruction=system_instruction, + apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, + chat_template=chat_template, + gen_prefix=self.doc_to_prefix(doc), + ) + + # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute + inst = self.construct_requests( + doc=doc, + ctx=fewshot_ctx, + metadata=(self.config["task"], doc_id, self.config.repeats), + apply_chat_template=apply_chat_template, + chat_template=chat_template, + ) + + if not isinstance(inst, list): + inst = [inst] + + instances.append(inst) + + # now flatten, this is to allow slicing to work with pickles + + sliced_instances = instances[:og_limit] + + flattened_instances = [ + instance + for instance_group in sliced_instances + for instance in instance_group + ] + + self._instances = flattened_instances + + if len(self._instances) == 0: + raise ValueError("task.build_requests() did not find any docs!") + + if cache_requests and (not cached_instances or rewrite_requests_cache): + save_to_cache(file_name=cache_key, obj=instances) + + @abc.abstractmethod + def construct_requests(self, doc, ctx, **kwargs): + """Uses RequestFactory to construct Requests and returns an iterable of + Requests which will be sent to the LM. + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param ctx: str + The context string, generated by fewshot_context. This includes the natural + language description, as well as the few shot examples, and the question + part of the document for `doc`. + :param doc_idx: int + The index of a document within `self.test_docs()` or `self.validation_docs()`, + whichever is the main split used. + :param repeats: int + TODO: update this docstring + The number of times each instance in a dataset is inferred on. Defaults to 1, + can be increased for techniques like majority voting. + """ + pass + + @abc.abstractmethod + def process_results(self, doc, results): + """Take a single document and the LM results and evaluates, returning a + dict where keys are the names of submetrics and values are the values of + the metric for that one document + + :param doc: + The document as returned from training_docs, validation_docs, or test_docs. + :param results: + The results of the requests created in construct_requests. + """ + pass + + @abc.abstractmethod + def aggregation(self): + """ + :returns: {str: [metric_score] -> float} + A dictionary where keys are the names of submetrics and values are + functions that aggregate a list of metric scores + """ + pass + + @abc.abstractmethod + def higher_is_better(self): + """ + :returns: {str: bool} + A dictionary where keys are the names of submetrics and values are + whether a higher value of the submetric is better + """ + pass + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @classmethod + def count_bytes(cls, doc): + """Used for byte-level perplexity metrics in rolling loglikelihood""" + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc): + """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) + + @utils.positional_deprecated + def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs): + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param rnd: random.Random + The pseudo-random number generator used to randomly sample examples. + WARNING: This is currently a required arg although it's optionalized with a default `None`. + :param description: str + The task's description that will be prepended to the fewshot examples. + :returns: str + The fewshot context. + """ + if rnd is None: + if self.fewshot_rnd is not None: + rnd = self.fewshot_rnd + else: + raise ValueError( + "A `random.Random` generator argument must be provided to `rnd`" + ) + + description = description if description else "" + + if num_fewshot == 0: + labeled_examples = "" + else: + # for sets with no training docs, draw from other set *but ensure no overlap with current doc* + if self.has_training_docs(): + fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) + else: + if self._fewshot_docs is None: + self._fewshot_docs = list( + self.validation_docs() + if self.has_validation_docs() + else self.test_docs() + ) + + fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = ( + "\n\n".join( + [ + self.doc_to_text(doc) + self.doc_to_target(doc) + for doc in fewshotex + ] + ) + + "\n\n" + ) + + example = self.doc_to_text(doc) + return description + labeled_examples + example + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def dump_config(self) -> dict: + """Returns the config as a dictionary.""" + # TODO: this should only return the overrides applied to a non-YAML task's configuration. + # (num_fewshot) + return self.config.to_dict() + + def set_config(self, key: str, value: Any, update: bool = False) -> None: + """Set or update the configuration for a given key.""" + if key is None: + raise ValueError("Key must be provided.") + + if update: + current_value = getattr(self._config, key, {}) + if not isinstance(current_value, dict): + raise TypeError( + f"Expected a dict for key '{key}', got {type(current_value).__name__} instead." + ) + current_value.update(value) + else: + setattr(self._config, key, value) + + def override_metric(self, metric_name: str) -> None: + """ + Override the default metrics used for evaluation with custom metrics. + + Parameters: + - metric_name (str): The name of the custom metric to override. Should be registered in api.metrics. + """ + ( + self._metric_fn_list, + self._aggregation_list, + self._metric_fn_kwargs, + self._higher_is_better, + ) = ({}, {}, {}, {}) + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._aggregation_list[metric_name] = get_metric_aggregation(metric_name) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + self._metric_fn_kwargs[metric_name] = {} + if not isinstance(self, ConfigurableTask): + self.process_results = lambda x, y: {metric_name: get_metric(metric_name)} + self.aggregation = lambda: { + metric_name: get_metric_aggregation(metric_name) + } + setattr(self._config, "metric_list", [{"metric": metric_name}]) + setattr(self._config, "process_results", None) + + def set_fewshot_seed(self, seed: Optional[int] = None) -> None: + self.fewshot_rnd = random.Random(seed) + if hasattr(self, "sampler"): + self.sampler.rnd = self.fewshot_rnd + + @property + def eval_docs(self) -> Union[datasets.Dataset, List[dict]]: + if self.has_test_docs(): + return self.test_docs() + elif self.has_validation_docs(): + return self.validation_docs() + else: + raise ValueError( + f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" + ) + + def doc_iterator( + self, + *, + rank: int = 0, + limit: Union[int, None] = None, + world_size: int = 1, + samples: Optional[List[int]] = None, + ) -> Iterator[Tuple[int, Any]]: + if samples: + n = len(self.eval_docs) + assert all([e < n for e in samples]), ( + f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}." + ) + eval_logger.info( + f"{self.config.task}: Evaluating on {len(samples)} examples" + ) + doc_iterator = utils.create_iterator( + enumerate(x for i, x in enumerate(self.eval_docs) if i in samples), + rank=int(rank), + limit=None, # limit does not matter here since we are selecting samples directly + world_size=int(world_size), + ) + else: + limit = int(limit) if limit else None + doc_iterator = utils.create_iterator( + enumerate(self.eval_docs), + rank=int(rank), + limit=limit, + world_size=int(world_size), + ) + return doc_iterator + + +class ConfigurableTask(Task): + VERSION = "Yaml" + OUTPUT_TYPE = None + CONFIG = None + + def __init__( + self, + data_dir=None, + cache_dir=None, + download_mode=None, + config: Optional[dict] = None, + ) -> None: # TODO no super() call here + # Get pre-configured attributes + self._config = self.CONFIG + + # Use new configurations if there was no preconfiguration + if self.config is None: + self._config = TaskConfig(**config) + # Overwrite configs + else: + if config is not None: + self._config.__dict__.update(config) + + if self.config is None: + raise ValueError( + "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" + ) + + if isinstance(self.config.metadata, dict): + if "version" in self.config.metadata: + self.VERSION = self.config.metadata["version"] + + if self.config.output_type is not None: + if self.config.output_type not in ALL_OUTPUT_TYPES: + raise ValueError( + f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'" + ) + self.OUTPUT_TYPE = self.config.output_type + + if self.config.doc_to_image is not None: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.doc_to_audio: + # mark the task as requiring multimodality. + self.MULTIMODAL = True + + if self.config.unsafe_code is not False: + self.UNSAFE_CODE = True + + if self.config.dataset_path is not None: + self.DATASET_PATH = self.config.dataset_path + + if self.config.dataset_name is not None: + self.DATASET_NAME = self.config.dataset_name + + self._metric_fn_list = {} + self._metric_fn_kwargs = {} + self._aggregation_list = {} + self._higher_is_better = {} + + if self.config.metric_list is None: + # TODO: handle this in TaskConfig.__post_init__ ? + _metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type] + + for metric_name in _metric_list: + self._metric_fn_list[metric_name] = get_metric(metric_name) + self._metric_fn_kwargs[metric_name] = {} + self._aggregation_list[metric_name] = get_metric_aggregation( + metric_name + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + else: + for metric_config in self.config.metric_list: + if "metric" not in metric_config: + raise ValueError( + "'metric' key not provided for an entry in 'metric_list', must be specified!" + ) + metric_name = metric_config["metric"] + kwargs = { + key: metric_config[key] + for key in metric_config + if key + not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"] + } + hf_evaluate_metric = ( + "hf_evaluate" in metric_config + and metric_config["hf_evaluate"] is True + ) + + if self.config.process_results is not None: + self._metric_fn_list[metric_name] = None + self._metric_fn_kwargs[metric_name] = {} + elif callable(metric_name): + metric_fn = metric_name.__call__ + metric_name = metric_name.__name__ + self._metric_fn_list[metric_name] = metric_fn + self._metric_fn_kwargs[metric_name] = kwargs + else: + self._metric_fn_list[metric_name] = get_metric( + metric_name, hf_evaluate_metric + ) + self._metric_fn_kwargs[metric_name] = kwargs + + if "aggregation" in metric_config: + agg_name = metric_config["aggregation"] + if isinstance(agg_name, str): + self._aggregation_list[metric_name] = get_aggregation(agg_name) + elif callable(agg_name): # noqa: E721 + self._aggregation_list[metric_name] = metric_config[ + "aggregation" + ] + else: + INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()} + metric_agg = get_metric_aggregation(metric_name) + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. " + f"using default " + f"aggregation={INV_AGG_REGISTRY[metric_agg]}" + ) + self._aggregation_list[metric_name] = metric_agg + + if "higher_is_better" in metric_config: + self._higher_is_better[metric_name] = metric_config[ + "higher_is_better" + ] + else: + eval_logger.warning( + f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. " + f"using default " + f"higher_is_better={is_higher_better(metric_name)}" + ) + self._higher_is_better[metric_name] = is_higher_better(metric_name) + + self.download(self.config.dataset_kwargs) + self._training_docs = None + self._fewshot_docs = None + + if self.config.filter_list is not None: + self._filters = [] + for filter_config in self.config.filter_list: + filter_name = filter_config["name"] + filter_functions = filter_config["filter"] + components = [] + for function in filter_functions: + kwargs = { + key: function[key] for key in function if key != "function" + } + components.append([function["function"], kwargs]) + filter_pipeline = build_filter_ensemble(filter_name, components) + self._filters.append(filter_pipeline) + else: + # TODO: handle repeats in a more general way rather than just discarding + eval_logger.debug( + "No custom filters defined. Using default 'take_first' filter for handling repeats." + ) + self._filters = [build_filter_ensemble("none", [["take_first", None]])] + + if self.config.use_prompt is not None: + eval_logger.info(f"loading prompt {self.config.use_prompt}") + self.prompt = get_prompt( + self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME + ) + else: + self.prompt = None + + if self.fewshot_docs() is not None: + self.fewshot_rnd = ( + random.Random() + ) # setting with no seed, to be overridden at a later time + config_sampler: Union[str, Callable] = ( + self.config.fewshot_config.get("sampler", "default") + if self.config.fewshot_config + else "default" + ) + if isinstance(config_sampler, str): + self.sampler = samplers.get_sampler(config_sampler)( + list(self.fewshot_docs()), self, rnd=self.fewshot_rnd + ) + elif callable(config_sampler) and issubclass( + config_sampler, samplers.ContextSampler + ): + self.sampler = config_sampler( + docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd + ) + else: + raise TypeError( + f"fewshot_config.sampler should be a string or callable of ContextSampler type, " + f"not {type(config_sampler)}" + ) + + self.task_docs = self.eval_docs + + # Test One Doc + self.features = list(self.task_docs.features.keys()) + self.multiple_input = 0 + self.multiple_target = 0 + test_doc = self.task_docs[0] + test_text = self.doc_to_text(test_doc) + test_target = self.doc_to_target(test_doc) + + if self.config.doc_to_choice is not None: + test_choice = self.doc_to_choice(test_doc) + if not isinstance(test_choice, list): + eval_logger.error("doc_to_choice must return list") + else: + num_choice = len(test_choice) + + if isinstance(test_text, int): + eval_logger.debug( + "doc_to_text returned an int. Assuming multiple inputs." + ) + self.multiple_input = num_choice + else: + test_choice = None + + if isinstance(test_target, list): + eval_logger.debug( + "doc_to_target returned a list. Assuming multiple targets." + ) + self.multiple_target = len(test_target) + else: + if (isinstance(test_target, int)) and (test_choice is not None): + test_target = test_choice[test_target] + else: + test_target = str(test_target) + + if test_choice is not None: + check_choices = test_choice + else: + check_choices = [test_target] + if self.config.doc_to_choice is not None: + for choice in check_choices: + choice_has_whitespace = True if choice[0].isspace() else False + delimiter_has_whitespace = ( + True + if self.config.target_delimiter.rstrip() + != self.config.target_delimiter + else False + ) + + if delimiter_has_whitespace and choice_has_whitespace: + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace' + ) + elif (not delimiter_has_whitespace) and (not choice_has_whitespace): + eval_logger.debug( + f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace' + ) + + def download( + self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs + ) -> None: + if isinstance(self.config.custom_dataset, Callable): + eval_logger.warning( + f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager." + + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme." + ) + self.dataset = self.config.custom_dataset( + **(self.config.metadata or {}), **(self.config.dataset_kwargs or {}) + ) + else: + self.dataset = datasets.load_dataset( + path=self.DATASET_PATH, + name=self.DATASET_NAME, + **dataset_kwargs if dataset_kwargs is not None else {}, + ) + + def has_training_docs(self) -> bool: + if self.config.training_split is not None: + return True + else: + return False + + def has_validation_docs(self) -> bool: + if self.config.validation_split is not None: + return True + else: + return False + + def has_test_docs(self) -> bool: + if self.config.test_split is not None: + return True + else: + return False + + def training_docs(self) -> datasets.Dataset: + if self.has_training_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.training_split] + ) + return self.dataset[self.config.training_split] + + def validation_docs(self) -> datasets.Dataset: + if self.has_validation_docs(): + if self.config.process_docs is not None: + return self.config.process_docs( + self.dataset[self.config.validation_split] + ) + return self.dataset[self.config.validation_split] + + def test_docs(self) -> datasets.Dataset: + if self.has_test_docs(): + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.test_split]) + return self.dataset[self.config.test_split] + + def fewshot_docs(self): + if self.config.fewshot_split is not None: + if self.config.process_docs is not None: + return self.config.process_docs(self.dataset[self.config.fewshot_split]) + return self.dataset[self.config.fewshot_split] + elif ( + self.config.fewshot_config is not None + and self.config.fewshot_config.get("samples", None) is not None + ): + if isinstance(self.config.fewshot_config["samples"], list): + return self.config.fewshot_config["samples"] + elif callable(self.config.fewshot_config["samples"]): + return self.config.fewshot_config["samples"]() + else: + raise Exception( + "`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list." + ) + else: + if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0): + eval_logger.warning( + f"[Task: {self.config.task}] " + "num_fewshot > 0 but fewshot_split is None. " + "using preconfigured rule." + ) + return super().fewshot_docs() + + @staticmethod + def append_target_question( + labeled_examples: List[Dict[str, str]], + question: str, + fewshot_as_multiturn: bool = False, + gen_prefix: Optional[str] = None, + ) -> None: + """Adds a target question to the labeled examples list. + If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry. + Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant. + """ + if not fewshot_as_multiturn: + # if no messages or last message is system, append as new user entry + if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system": + labeled_examples.append({"role": "user", "content": question}) + # if last message is user, append to it to avoid two user messages in a row + else: + labeled_examples[-1]["content"] += question + else: + # if fewshot_as_multiturn is True, append as next user entry (last is always assistant) + labeled_examples.append({"role": "user", "content": question}) + if gen_prefix: + labeled_examples.append({"role": "assistant", "content": gen_prefix}) + + @utils.positional_deprecated + def fewshot_context( + self, + doc: dict, + num_fewshot: int, + system_instruction: Optional[str] = None, + apply_chat_template: bool = False, + fewshot_as_multiturn: bool = False, + chat_template: Optional[Callable] = None, + gen_prefix: Optional[str] = None, + ) -> Union[str, List[str]]: + """Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param system_instruction: str + System instruction to be applied to the prompt. + :param apply_chat_template: bool + Whether to apply the chat template to the fewshot context. + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param chat_template: + callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string. + :param gen_prefix: + String to append after the <|assistant|> token. + :returns: str + The fewshot context. + """ + if apply_chat_template: + labeled_examples = [] + else: + labeled_examples = "" + + # get task description + if description := self.config.description: + description = utils.apply_template(self.config.description, doc) + + # create system prompt based on the provided system instruction and description + if system_instruction is not None and description: + system_prompt = ( + f"{system_instruction}{self.sampler.fewshot_delimiter}{description}" + ) + elif system_instruction is not None: + system_prompt = system_instruction + elif description: + system_prompt = description + else: + system_prompt = "" + + # add system prompt if specified + if system_prompt: + if apply_chat_template: + labeled_examples.append({"role": "system", "content": system_prompt}) + else: + labeled_examples = system_prompt + # if few-shot - append examples after the system prompt + if num_fewshot > 0: + if apply_chat_template: + labeled_examples.extend( + self.sampler.get_chat_context( + doc, + num_fewshot, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + ) + else: + labeled_examples += self.sampler.get_context( + doc, num_fewshot, gen_prefix=gen_prefix + ) + + example = self.doc_to_text(doc) + if apply_chat_template: + if self.multiple_input: + # TODO: append prefill? + if not labeled_examples: + return "" + return chat_template(labeled_examples) + if isinstance(example, str): + self.append_target_question( + labeled_examples, + example, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # for loglikelihood create a list of questions with appended choices + elif isinstance(example, list): + labeled_examples_list = [] + # copy chat history for each example and append the answer + for ex in example: + chat = deepcopy(labeled_examples) + self.append_target_question( + chat, + ex, + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # TODO: append prefill? + labeled_examples_list.append( + chat_template( + chat, + add_generation_prompt=False if gen_prefix else True, + ) + ) + return labeled_examples_list + # if example is an integer, append the choice or convert to string + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + self.append_target_question( + labeled_examples, + choices[example], + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + else: + self.append_target_question( + labeled_examples, + str(example), + fewshot_as_multiturn, + gen_prefix=gen_prefix, + ) + # return lm.apply_chat_template(labeled_examples) + return chat_template( + labeled_examples, + add_generation_prompt=False if gen_prefix else True, + ) + else: + prefix = ( + self.config.target_delimiter + gen_prefix + if gen_prefix is not None + else "" + ) + if self.multiple_input: + return labeled_examples + if isinstance(example, str): + return labeled_examples + example + prefix + elif isinstance(example, list): + return [labeled_examples + ex + prefix for ex in example] + elif isinstance(example, int): + if self.config.doc_to_choice is not None: + choices = self.doc_to_choice(doc) + return labeled_examples + choices[example] + prefix + else: + return labeled_examples + str(example) + prefix + + def apply_filters(self) -> Optional[List[Instance]]: + """Iterates over FilterEnsembles and applies them to instances""" + if hasattr(self, "_filters"): + for f in self._filters: + f.apply(self._instances) + else: + eval_logger.warning("No filter defined, passing through instances") + return self._instances + + def should_decontaminate(self): + return self.config.should_decontaminate + + def doc_to_decontamination_query(self, doc: dict): + if self.config.should_decontaminate: + if self.config.doc_to_decontamination_query is None: + return self.doc_to_text(doc) + else: + doc_to_decontamination_query = self.config.doc_to_decontamination_query + if doc_to_decontamination_query in self.features: + return doc[doc_to_decontamination_query] + elif callable(doc_to_decontamination_query): + return doc_to_decontamination_query(doc) + else: + return ast.literal_eval( + utils.apply_template( + self.config.doc_to_decontamination_query, doc + ) + ) + + def _process_doc(self, doc: dict) -> dict: + """ + Override this to process (detokenize, strip, replace, etc.) individual + documents. This can be used in a map over documents of a data split. + E.g. `map(self._process_doc, self.dataset["validation"])` + + :return: dict + The processed version of the specified `doc`. + """ + return doc + + def doc_to_text(self, doc, doc_to_text=None): + if self.prompt is not None: + doc_to_text = self.prompt + elif doc_to_text is not None: + doc_to_text = doc_to_text + else: + doc_to_text = self.config.doc_to_text + + if isinstance(doc_to_text, int): + return doc_to_text + elif isinstance(doc_to_text, str): + if doc_to_text in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_text]] + # else: + return doc[doc_to_text] + else: + text_string = utils.apply_template(doc_to_text, doc) + if text_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(text_string) + else: + return text_string + elif callable(doc_to_text): + return doc_to_text(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_text, "apply"): + applied_prompt = doc_to_text.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[0] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + print(type(doc_to_text)) + raise TypeError + + def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]: + if self.prompt is not None: + doc_to_target = self.prompt + elif doc_to_target is not None: + doc_to_target = doc_to_target + else: + doc_to_target = self.config.doc_to_target + + if isinstance(doc_to_target, int): + return doc_to_target + elif isinstance(doc_to_target, str): + if doc_to_target in self.features: + # if self.config.doc_to_choice is not None: + # return self.doc_to_choice(doc)[doc[doc_to_target]] + # else: + return doc[doc_to_target] + else: + target_string = utils.apply_template(doc_to_target, doc) + if target_string.isdigit() and self._config.doc_to_choice is not None: + return ast.literal_eval(target_string) + elif ( + len(target_string) >= 2 + and (target_string[0] == "[") + and (target_string[-1] == "]") + ): + try: + return ast.literal_eval(target_string) + except (SyntaxError, ValueError): + return target_string + else: + return target_string + elif isinstance(doc_to_target, list): + return doc_to_target + elif callable(doc_to_target): + return doc_to_target(doc) + # Used when applying a Promptsource template + elif hasattr(doc_to_target, "apply"): + applied_prompt = doc_to_target.apply(doc) + if len(applied_prompt) == 2: + return applied_prompt[1] + else: + eval_logger.warning("Applied prompt returns empty string") + return self.config.fewshot_delimiter + else: + raise TypeError + + def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]: + if self.prompt is not None: + doc_to_choice = self.prompt + elif doc_to_choice is not None: + doc_to_choice = doc_to_choice + elif self.config.doc_to_choice is None: + eval_logger.error("doc_to_choice was called but not set in config") + else: + doc_to_choice = self.config.doc_to_choice + + if isinstance(doc_to_choice, str): + if doc_to_choice in self.features: + return doc[doc_to_choice] + else: + return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) + elif isinstance(doc_to_choice, list): + return doc_to_choice + elif isinstance(doc_to_choice, dict): + return list(doc_to_choice.values()) + elif callable(doc_to_choice): + return doc_to_choice(doc) + elif hasattr(doc_to_choice, "get_answer_choices_list"): + return doc_to_choice.get_answer_choices_list(doc) + else: + raise TypeError + + def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]: + if doc_to_image is not None: + doc_to_image = doc_to_image + elif self.config.doc_to_image is not None: + doc_to_image = self.config.doc_to_image + else: + return None + + if isinstance(doc_to_image, list): + image_feature = [ + self.doc_to_image(doc, feature) for feature in doc_to_image + ] + return [feature for feature in image_feature if feature is not None] + elif isinstance(doc_to_image, str): + if doc_to_image in self.features: + return doc[doc_to_image] + else: + return ast.literal_eval(utils.apply_template(doc_to_image, doc)) + elif callable(doc_to_image): + return doc_to_image(doc) + else: + return None + + def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]: + if doc_to_audio is not None: + doc_to_audio = doc_to_audio + elif self.config.doc_to_audio is not None: + doc_to_audio = self.config.doc_to_audio + else: + return None + + if isinstance(doc_to_audio, list): + audio_feature = [ + self.doc_to_audio(doc, feature) for feature in doc_to_audio + ] + return [feature for feature in audio_feature if feature is not None] + elif isinstance(doc_to_audio, str): + if doc_to_audio in self.features: + return doc[doc_to_audio] + else: + return ast.literal_eval(utils.apply_template(doc_to_audio, doc)) + elif callable(doc_to_audio): + return doc_to_audio(doc) + else: + return None + + def doc_to_prefix(self, doc): + if (gen_prefix := self.config.gen_prefix) is not None: + if gen_prefix in self.features: + return doc[gen_prefix] + else: + return utils.apply_template(gen_prefix, doc) + return None + + def construct_requests( + self, doc: dict, ctx: str, **kwargs + ) -> Union[List[Instance], Instance]: + apply_chat_template = kwargs.pop("apply_chat_template", False) + chat_template: Callable | None = kwargs.pop("chat_template", None) + + aux_arguments = None + + if self.OUTPUT_TYPE == "loglikelihood": + arguments = (ctx, self.doc_to_target(doc)) + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + arguments = (self.doc_to_target(doc),) + elif self.OUTPUT_TYPE == "multiple_choice": + choices = self.doc_to_choice(doc) + target_delimiter = self.config.target_delimiter + if apply_chat_template: + target_delimiter = "" + if self.multiple_input: + # If there are multiple inputs, choices are placed in the ctx + # apply chat_template to choices if apply_chat_template + cont = self.doc_to_target(doc) + + arguments = [ + ( + ctx + + ( + chat_template([{"role": "user", "content": choice}]) + if apply_chat_template + else choice + ), + f"{target_delimiter}{cont}", + ) + for choice in choices + ] + else: + # Otherwise they are placed in the continuation + arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] + + # TODO: we should raise a warning telling users this will at most ~2x runtime. + if "acc_mutual_info" in self._metric_fn_list.keys(): + # if we are calculating multiple choice accuracy + # using mutual information instead of raw loglikelihood as metric, need unconditional lls. + + # here mutual info refers to calculating + # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) + # in other words normalizing by subtracting the unconditional logprob of each choice. + # TODO: should these be strided? will have to modify the processing in process_results if so + aux_arguments = [ + ("", f"{target_delimiter}{choice}") for choice in choices + ] + + arguments.extend(aux_arguments) + + elif self.OUTPUT_TYPE == "generate_until": + arguments = (ctx, deepcopy(self.config.generation_kwargs)) + + multimodal_arg = {} + if ( + self.config.doc_to_image + ): # TODO: ensure that non-multimodal tasks aren't getting visual args + multimodal_arg = { + **multimodal_arg, + **{"visual": self.doc_to_image(doc)}, + } + + if ( + self.config.doc_to_audio + ): # TODO: ensure that non-multimodal tasks aren't getting audio args + multimodal_arg = { + **multimodal_arg, + **{"audio": self.doc_to_audio(doc)}, + } + + if bool(multimodal_arg): + if isinstance(arguments, list): + arguments = [arg + (multimodal_arg,) for arg in arguments] + else: + arguments = arguments + (multimodal_arg,) + + if self.OUTPUT_TYPE == "multiple_choice": + request_list = [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=arg, + idx=i, + **kwargs, + ) + for i, arg in enumerate(arguments) + ] + + return request_list + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=arguments, + idx=0, + **kwargs, + ) + + def process_results(self, doc, results): + if callable(self.config.process_results): + return self.config.process_results(doc, results) + + result_dict = {} + use_metric = list(self._metric_fn_list.keys()) + if self.OUTPUT_TYPE == "loglikelihood": + results = results[0] + ll, is_greedy = results + return { + **({"perplexity": ll} if "perplexity" in use_metric else {}), + **({"acc": int(is_greedy)} if "acc" in use_metric else {}), + } + elif self.OUTPUT_TYPE == "loglikelihood_rolling": + (loglikelihood,) = results + _words = self.count_words(self.doc_to_target(doc)) + _bytes = self.count_bytes(self.doc_to_target(doc)) + return { + **( + {"word_perplexity": (loglikelihood, _words)} + if "word_perplexity" in use_metric + else {} + ), + **( + {"byte_perplexity": (loglikelihood, _bytes)} + if "byte_perplexity" in use_metric + else {} + ), + **( + {"bits_per_byte": (loglikelihood, _bytes)} + if "bits_per_byte" in use_metric + else {} + ), + } + elif self.OUTPUT_TYPE == "multiple_choice": + lls, is_greedy = zip(*results) + + # retrieve choices in List[str] form, to compute choice lengths, etc. + choices = self.doc_to_choice(doc) + completion_len = np.array([float(len(i)) for i in choices]) + + if ( + 2 * len(choices) == len(lls) + and "acc_mutual_info" in self._metric_fn_list.keys() + ): + # then we are doing mutual info. + # this stores the "dryrun" / unconditional answer loglikelihoods + # as we extend the args list with unconditional ("", continuation) pairs + lls_unconditional = lls[len(choices) :] + if len(lls_unconditional) != len(choices): + raise ValueError + # and this stores our "regular" conditional loglikelihoods + lls = lls[: len(choices)] + + pred = np.argmax(lls) + pred_norm = np.argmax(lls / completion_len) + + if self.multiple_input: + gold = self.doc_to_text(doc) + else: + gold = self.doc_to_target(doc) + + gold_index_error = False + if isinstance(gold, list): + gold = [i if i < len(choices) else -100 for i in gold] + if -100 in gold: + gold_index_error = True + else: + if isinstance(gold, int): + gold = gold if gold < len(choices) else -100 + elif isinstance(gold, str): + gold = choices.index(gold) if gold in choices else -100 + + if gold == -100: + gold_index_error = True + + if gold_index_error: + eval_logger.warning( + f"Label index was not in within range of available choices," + f"Sample:\n\n{doc}\n\n" + ) + + if self.multiple_target: + acc = 1.0 if pred in gold else 0.0 + acc_norm = 1.0 if pred_norm in gold else 0.0 + exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold])) + else: + acc = 1.0 if pred == gold else 0.0 + acc_norm = 1.0 if pred_norm == gold else 0.0 + # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly + exact_match = int(is_greedy[gold]) if gold != -100 else 0 + + prob_norm = utils.softmax(lls) + + # TODO use keyword arguments to the metric? + # gold, pred, norm stuff, the original lls, + result_dict = { + **({"acc": acc} if "acc" in use_metric else {}), + **({"f1": (gold, pred)} if "f1" in use_metric else {}), + **({"mcc": (gold, pred)} if "mcc" in use_metric else {}), + **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), + **({"exact_match": exact_match} if "exact_match" in use_metric else {}), + **( + {"brier_score": (gold, prob_norm)} + if "brier_score" in use_metric + else {} + ), + } + + if "acc_mutual_info" in use_metric: + lls_mutual_info = [ + ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional) + ] + acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0 + result_dict["acc_mutual_info"] = acc_mutual_info + + elif self.OUTPUT_TYPE == "generate_until": + gold = self.doc_to_target(doc) + result = results[0] + if self.config.doc_to_choice is not None: + # If you set doc_to_choice, + # it assumes that doc_to_target returns a number. + choices = self.doc_to_choice(doc) + gold = choices[gold] + # we expect multiple_targets to be a list. + elif self.multiple_target: + gold = list(gold) + # TODO: handle this better + elif type(gold) is not type(result) and not ( + "bypass" in self._metric_fn_list.keys() or isinstance(result, list) + ): + # cast gold to the same type as result + gold = type(result)(gold) + + for metric in self._metric_fn_list.keys(): + if self.multiple_target: + # in the case where we have multiple targets, + # return true if any are true + # TODO: this may break for multipLe_target, non zero-or-1 metrics + scores = [] + if not isinstance(gold, list): + # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer + # print(gold) + gold = [gold] + if metric == "exact_match": + result = [result for _ in range(len(gold))] + scores = self._metric_fn_list[metric]( + references=gold, + predictions=result, + **self._metric_fn_kwargs[metric], + )[metric] + result_score = 1.0 if scores > 0.0 else 0.0 + else: + for gold_option in gold: + try: + result_score = self._metric_fn_list[metric]( + references=[gold_option], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except ( + TypeError + ): # TODO: this is hacky and I don't want to do it + result_score = self._metric_fn_list[metric]( + [gold_option, result] + ) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + result_score = result_score[metric] + scores.append(result_score) + if any(scores): + result_score = 1.0 + else: + result_score = 0.0 + else: + try: + result_score = self._metric_fn_list[metric]( + references=[gold], + predictions=[result], + **self._metric_fn_kwargs[metric], + ) + except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics + result_score = self._metric_fn_list[metric]([gold, result]) + if isinstance(result_score, dict): + # TODO: this handles the case where HF evaluate returns a dict. + # This allows for multiple metrics to be returned from the same function + for k, v in result_score.items(): + result_dict[k] = v + else: + result_dict[metric] = result_score + else: + raise ValueError( + f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ", + "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'", + ) + + return result_dict + + def aggregation(self) -> dict: + return self._aggregation_list + + def higher_is_better(self) -> dict: + return self._higher_is_better + + def get_config(self, key: str) -> Any: + return getattr(self._config, key, None) + + @property + def task_name(self) -> Any: + return getattr(self.config, "task", None) + + def __repr__(self): + return ( + f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," + f"output_type={self.OUTPUT_TYPE}," + f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," + f"num_samples={len(self.eval_docs)})" + ) + + +class MultipleChoiceTask(Task): + OUTPUT_TYPE = "loglikelihood" + + def doc_to_target(self, doc: dict) -> str: + return " " + doc["choices"][doc["gold"]] + + def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]: + # TODO: add mutual info here? + return [ + Instance( + request_type="loglikelihood", + doc=doc, + arguments=(ctx, " {}".format(choice)), + idx=i, + **kwargs, + ) + for i, choice in enumerate(doc["choices"]) + ] + + def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict: + results = [ + res[0] for res in results + ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? + gold = doc["gold"] + + acc = 1.0 if np.argmax(results) == gold else 0.0 + completion_len = np.array([float(len(i)) for i in doc["choices"]]) + acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0 + + return { + "acc": acc, + "acc_norm": acc_norm, + } + + def higher_is_better(self) -> dict: + return { + "acc": True, + "acc_norm": True, + } + + def aggregation(self) -> dict: + return { + "acc": mean, + "acc_norm": mean, + } + + +class PerplexityTask(Task): + OUTPUT_TYPE = "loglikelihood_rolling" + + def has_training_docs(self) -> bool: + return False + + def fewshot_examples(self, k: int, rnd) -> List: + if k != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + return [] + + def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]: + if num_fewshot != 0: + raise ValueError( + "The number of fewshot examples must be 0 for perplexity tasks." + ) + + return "" + + def higher_is_better(self) -> dict: + return { + "word_perplexity": False, + "byte_perplexity": False, + "bits_per_byte": False, + } + + def doc_to_decontamination_query(self, doc): + return doc + + def doc_to_text(self, doc) -> str: + return "" + + def doc_to_target(self, doc): + return doc + + def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs): + if bool(ctx): + raise ValueError + + return Instance( + request_type=self.OUTPUT_TYPE, + doc=doc, + arguments=(self.doc_to_target(doc),), + idx=0, + **kwargs, + ) + + def process_results(self, doc: dict, results: Tuple[float]) -> dict: + (loglikelihood,) = results + words = self.count_words(self.doc_to_target(doc)) + bytes_ = self.count_bytes(self.doc_to_target(doc)) + return { + "word_perplexity": (loglikelihood, words), + "byte_perplexity": (loglikelihood, bytes_), + "bits_per_byte": (loglikelihood, bytes_), + } + + def aggregation(self) -> dict: + return { + "word_perplexity": weighted_perplexity, + "byte_perplexity": weighted_perplexity, + "bits_per_byte": bits_per_byte, + } + + @classmethod + def count_bytes(cls, doc) -> int: + return len(doc.encode("utf-8")) + + @classmethod + def count_words(cls, doc) -> int: + """Downstream tasks with custom word boundaries should override this!""" + return len(re.split(r"\s+", doc)) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/caching/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/caching/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/caching/cache.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/caching/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d293b0ff8b1ebac186f5ac078cdb49227562db --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/caching/cache.py @@ -0,0 +1,59 @@ +import hashlib +import logging +import os + +import dill + + +eval_logger = logging.getLogger(__name__) + + +MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) + +OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH") + + +PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache" + +# This should be sufficient for uniqueness +HASH_INPUT = "EleutherAI-lm-evaluation-harness" + +HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest() + +FILE_SUFFIX = f".{HASH_PREFIX}.pickle" + + +def load_from_cache(file_name: str, cache: bool = False): + if not cache: + return + try: + path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + with open(path, "rb") as file: + cached_task_dict = dill.loads(file.read()) + return cached_task_dict + + except Exception: + eval_logger.debug(f"{file_name} is not cached, generating...") + pass + + +def save_to_cache(file_name, obj): + if not os.path.exists(PATH): + os.mkdir(PATH) + + file_path = f"{PATH}/{file_name}{FILE_SUFFIX}" + + eval_logger.debug(f"Saving {file_path} to cache...") + with open(file_path, "wb") as file: + file.write(dill.dumps(obj)) + + +# NOTE the "key" param is to allow for flexibility +def delete_cache(key: str = ""): + files = os.listdir(PATH) + + for file in files: + if file.startswith(key) and file.endswith(FILE_SUFFIX): + file_path = f"{PATH}/{file}" + os.unlink(file_path) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/archiver.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/archiver.py new file mode 100644 index 0000000000000000000000000000000000000000..c132232116c2ae5f5ab1dc3a2a0afc0dbd4ef1bd --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/archiver.py @@ -0,0 +1,174 @@ +import datetime +import io +import json +import mmap +import os +from pathlib import Path +from typing import Any + +import jsonlines +import tqdm +import zstandard + + +def json_serial(obj: Any) -> str: + """JSON serializer for objects not serializable by default json code""" + + if isinstance(obj, (datetime.datetime,)): + return obj.isoformat() + raise TypeError("Type %s not serializable" % type(obj)) + + +# Modified version of lm_dataformat Archive for single file. +class Archive: + def __init__(self, file_path: str, compression_level: int = 3) -> None: + self.file_path = file_path + dir_name = os.path.dirname(file_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + self.fh = open(self.file_path, "wb") + self.cctx = zstandard.ZstdCompressor(level=compression_level) + self.compressor = self.cctx.stream_writer(self.fh) + + def add_data(self, data, meta=None) -> None: + if meta is None: + meta = {} + self.compressor.write( + json.dumps({"text": data, "meta": meta}, default=json_serial).encode( + "UTF-8" + ) + + b"\n" + ) + + def commit(self) -> None: + self.compressor.flush(zstandard.FLUSH_FRAME) + self.fh.flush() + self.fh.close() + + +# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm. +class Reader: + def __init__(self) -> None: + pass + + def read( + self, + file, + get_meta: bool = False, + autojoin_paragraphs: bool = True, + para_joiner: str = "\n\n", + ): + with open(file, "rb") as fh: + self.fh = fh + cctx = zstandard.ZstdDecompressor() + reader = io.BufferedReader(cctx.stream_reader(fh)) + rdr = jsonlines.Reader(reader) + for ob in rdr: + # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility. + if isinstance(ob, str): + assert not get_meta + yield ob + continue + + text = ob["text"] + + if autojoin_paragraphs and isinstance(text, list): + text = para_joiner.join(text) + + if get_meta: + yield text, (ob["meta"] if "meta" in ob else {}) + else: + yield text + + +class TextArchive: + def __init__(self, file_path, mode: str = "rb+") -> None: + self.file_path = file_path + dir_name = os.path.dirname(file_path) + if dir_name: + os.makedirs(dir_name, exist_ok=True) + + if not os.path.exists(file_path): + Path(file_path).touch() + + self.fh = open(self.file_path, mode) + + def add_data(self, data) -> None: + self.fh.write(data.encode("UTF-8") + b"\n") + + def commit(self) -> None: + self.fh.flush() + self.fh.close() + + +class TextReader: + def __init__(self, file_path) -> None: + self.file_path = file_path + + # Optimized mmap read with infrequent tqdm updates to maintain speed + # Tested up to 250MB/s. + def read_tqdm(self, update_frequency: int = 10000): + current_file_position = 0 + line_counter = 0 + with ( + open(self.file_path, "r", encoding="utf-8") as fh, + tqdm.tqdm( + total=os.path.getsize(self.file_path), + dynamic_ncols=True, + unit="byte", + unit_scale=1, + ) as progress, + ): + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + line_counter += 1 + if line_counter == update_frequency: + new_file_pos = mmap_obj.tell() + bytes_read = new_file_pos - current_file_position + current_file_position = new_file_pos + progress.update(bytes_read) + line_counter = 0 + yield line[:-1] + + def read_and_tell(self): + current_file_position = 0 + with open(self.file_path, "r", encoding="utf8") as fh: + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + new_file_pos = mmap_obj.tell() + raw_bytes_read = new_file_pos - current_file_position + current_file_position = new_file_pos + yield line[:-1], raw_bytes_read + + def read(self): + with open(self.file_path, "r", encoding="utf8") as fh: + with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: + for line in iter(mmap_obj.readline, b""): + line = line.decode("utf-8") + yield line[:-1] + + def read_slow(self): + with open(self.file_path, "r", encoding="utf8") as fh: + while True: + line = fh.readline() + if line == -1 or line == "": + break + else: + yield line[:-1] + + +# Optimized for speed. Decompresses the archive in shell before +# using the mmap'd TextReader. +class ZStdTextReader: + def __init__(self, file) -> None: + self.file = file + + def read_tqdm(self): + decompressed_file = self.file[:-4] + print("Decompressing file, please wait...") + os.system(f"zstd -d {self.file}") # linux decompress is faster + reader = TextReader(decompressed_file) + yield from reader.read_tqdm() + os.remove(decompressed_file) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/decontaminate.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/decontaminate.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1250d39bf7cd0272e412452d970ec7c52992c5 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/decontaminate.py @@ -0,0 +1,166 @@ +import collections +import glob +import json +import os +import pickle +import random +import time + +from .archiver import ZStdTextReader +from .janitor import Janitor, word_ngrams + + +# Was used for testing the evaluator decoupled from the full logic below +def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str): + simulated_overlap = 0.1 + contaminated = int(len(docs) * simulated_overlap) + return random.sample(range(len(docs)), contaminated) + + +# Returns a dictionary containing all overlapping documents in each +# task. In the standard use case, an overlap occurs when any of the 13-grams +# found in the task document exist in the training set documents. +# +# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these +# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst" +# files. These should exist in the "ngrams_path" provided to this function. + + +# Algorithm: +# 1. Build lookups for each dataset {ngram: list(document_ids)} +# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]} +# 3. Full scan the 13-grams from the training set against the merged lookup, +# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)} +# 4. Strip the task_set from the dictionary keys and return +# +# We cache the task+set lookups as well as the overlaps. +def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict: + # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size) + + info_dict_path = os.path.join(ngrams_path, "info.json") + info_dict = json.load(open(info_dict_path, "r", encoding="utf-8")) + ngrams_n_size = info_dict["ngram_size"] + + janitor = Janitor() + + # Build lookup for each dataset first in case we use different task combinations later + print("Building Lookups...") + start = time.perf_counter() + + def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str: + return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps" + + lookups = {} + duplicates = {} # (task_name, task_set): set(doc_ids)} + sets_to_decontaminate = len(docs_by_task_set.keys()) + + for (task_name, task_set), docs in docs_by_task_set.items(): + if not os.path.exists(f"data/{task_name}"): + os.mkdir(f"data/{task_name}") + + # Check if we've decontaminated this combination before + overlaps_dump_path = get_overlaps_dump_path( + task_name, task_set, ngrams_n_size, limit + ) + if os.path.exists(overlaps_dump_path): + duplicates[(task_name, task_set)] = pickle.load( + open(overlaps_dump_path, "rb") + ) + sets_to_decontaminate -= 1 + continue + else: + duplicates[(task_name, task_set)] = set() + + # Build/load the task lookup {ngram: set(documents)}. + task_set_lookup_path = ( + f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup" + ) + if os.path.exists(task_set_lookup_path): + print(f"{task_set_lookup_path} available, loading...") + lookups[(task_name, task_set)] = pickle.load( + open(task_set_lookup_path, "rb") + ) + else: + print(f"{task_set_lookup_path} not available, building...") + lookup = collections.defaultdict(set) + + for doc_id, document in enumerate(docs): + ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size) + for ngram in ngrams: + lookup[ngram].add(doc_id) + + pickle.dump(lookup, open(task_set_lookup_path, "wb")) + lookups[(task_name, task_set)] = lookup + + elapsed = time.perf_counter() - start + print(f"Building lookups took {elapsed:0.5f} seconds.") + + matched_ngrams = [] + + if sets_to_decontaminate > 0: + print("Merging lookups...") + start = time.perf_counter() + merged_lookup = collections.defaultdict(list) + for (task_name, task_set), lookup in lookups.items(): + for ngram, doc_ids in lookup.items(): + merged_lookup[ngram].append((task_name, task_set, doc_ids)) + + elapsed = time.perf_counter() - start + print(f"Merging lookups took {elapsed:0.5f} seconds.") + + print(f"{ngrams_n_size} grams files found in {ngrams_path}:") + files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst")) + print(files) + + for file in files: + start = time.perf_counter() + print(f"Scanning {file}") + reader = ZStdTextReader(file) + total_ngrams = 0 + unique_ngrams = 0 + matching_unique = 0 + non_matching_unique = 0 + + current_ngram = "" + for line in reader.read_tqdm(): # Scan training set ngrams file + total_ngrams += 1 + [ngram, document_id] = line.rsplit(" ", 1) + if ( + ngram != current_ngram + ): # Only need to match the ngram once in training set + unique_ngrams += 1 + current_ngram = ngram + if ngram in merged_lookup: + matched_ngrams.append(ngram) # For logging + matching_unique += 1 + for task_name, task_set, doc_ids in merged_lookup[ngram]: + task_doc_set = duplicates[(task_name, task_set)] + for doc_id in doc_ids: # Record contamination across all relevant task/set combos + task_doc_set.add(doc_id) + del merged_lookup[ngram] # No point matching again + else: + non_matching_unique += 1 + + print(f"Total Ngrams: {total_ngrams}") + print(f"Unique Ngrams: {unique_ngrams}") + print(f"Unique Matching: {matching_unique}") + print(f"Unique Non Matching: {non_matching_unique}") + print("Matched ngrams:") + for ngram in matched_ngrams: + print(ngram) + + elapsed = time.perf_counter() - start + print(f"Read took {elapsed:0.5f} seconds.") + print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second") + + print(duplicates) + + # Dump overlaps separately + for (task_name, task_set), doc_ids in duplicates.items(): + overlaps_dump_path = get_overlaps_dump_path( + task_name, task_set, ngrams_n_size, limit + ) + pickle.dump(doc_ids, open(overlaps_dump_path, "wb")) + + # Strip task set and return + return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()} diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/janitor.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/janitor.py new file mode 100644 index 0000000000000000000000000000000000000000..cedf8a5717aa8156674836ba236fdcabf36e0487 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/decontamination/janitor.py @@ -0,0 +1,328 @@ +import pickle +import re +import string +import traceback +from typing import Iterator, List, Sequence, Tuple, TypeVar + + +# This is a cpp module. Compile janitor_util.cpp with: +# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup +try: + import janitor_util + + JANITOR_CPP = True +except Exception: + print("WARNING: C++ module could not be loaded. Janitor running in python mode") + traceback.print_exc() + JANITOR_CPP = False + +T = TypeVar("T") + + +# Implementation from nltk source +# https://www.nltk.org/_modules/nltk/util.html +def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]: + history = [] + while n > 1: + # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator + try: + next_item = next(sequence) + except StopIteration: + # no more data, terminate the generator + return + history.append(next_item) + n -= 1 + for item in sequence: + history.append(item) + yield tuple(history) + del history[0] + + +def word_ngrams(s: str, n: int) -> Iterator[str]: + """Splits a string into ngram words""" + tokens = s.split() # not a generator :( + ngram_seqs = form_ngrams(iter(tokens), n) + return (" ".join(ngram) for ngram in ngram_seqs) + + +# Does character sequences only - combined faster function to play around with later +# def word_ngrams_indices_combined(sequence, n): +# current_word = "" +# history = [] +# gap = False; +# start = 0 +# end = 0 +# for character in sequence: +# if character == " ": +# if not gap: +# gap = True +# history.append(current_word) +# end += len(current_word) - 1 +# current_word = "" +# if len(history) == n: +# yield (tuple(history), start, end) +# del history[0] +# start = end + 1 +# end = start +# else: +# gap = False +# current_word += character + + +# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python +def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string on whitespaces and records the indices of each in the original string. + @:return generator((word, (start_idx, end_idx)), ...) + """ + return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s)) + + +def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]: + """Splits a string into pairs of (ngram words, their start/end indices)""" + tokens_with_indices = split_indices(s) + + # Generator of ngrams of (word, idx_pairs) + # ( + # [(word, (start,end)), (word, (start, end))...], + # [(word, (start, end)), ...], + # ... + # ) + ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n) + + # Generator of pairs of word and index ngrams + # ( + # ([word, word, ...], [(start,end), (start,end), ...]), + # ... + # ) + ngram_indices_pairs = ( + zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices + ) + + # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...) + return ( + (" ".join(ngram_seq), (indices[0][0], indices[-1][1])) + for ngram_seq, indices in ngram_indices_pairs + ) + + +class Janitor: + # FIXME delete_chars: Should anything else go here? Special chars? + def __init__( + self, + ngram_n: int = 13, + window_to_remove: int = 200, + too_dirty_cutoff: int = 10, + minimum_slice_length: int = 200, + delete_chars: str = string.punctuation, + ) -> None: + self.ngram_n = ngram_n + self.window_to_remove = window_to_remove + self.too_dirty_cutoff = too_dirty_cutoff + self.minimum_slice_length = minimum_slice_length + self.delete_chars = delete_chars + + self.dirt_ngrams = set() + + # If in python, we'll translate uppercase to lowercase and delete naughty characters. + # This is fast by python standards + # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st + self.translation_table = str.maketrans( + string.ascii_lowercase + string.ascii_uppercase, # These characters + string.ascii_lowercase * 2, # Become these characters + self.delete_chars, # These are deleted + ) + + ############## + # I/O for saving contamination ngrams + ############## + + def save_contamination_ngrams(self, filename: str) -> None: + with open(filename, "wb") as fp: + pickle.dump(filename, fp) + + def load_contamination_ngrams(self, filename: str) -> None: + with open(filename, "rb") as fp: + self.dirt_ngrams = pickle.load(fp) + + ############## + # Call these :) + ############## + + def register_contaminant(self, dirt_string: str) -> None: + """Register a string as contamination to be removed, e.g. a test set + This breaks the dirt_string into ngrams to store for future cleaning""" + if JANITOR_CPP: + return self.register_contaminant_cpp(dirt_string) + else: + print("WARNING: Janitor running in python mode") + return self.register_contaminant_python(dirt_string) + + def clean(self, dirty_string: str) -> List[str]: + """Clean a string (e.g. a training set) by removing all ngrams previously + registered as contaminants. Returns a list of clean chunks, or empty if + the string was too dirty""" + if JANITOR_CPP: + return self.clean_cpp(dirty_string) + else: + print("WARNING: Janitor running in python mode") + return self.clean_python(dirty_string) + + def _split_chunks( + self, dirty_string: str, dirty_parts: Sequence[Tuple] + ) -> List[str]: + clean_chunks = [] + splice_idx = 0 + end = -1 + for i, (ngram, start, end) in enumerate(dirty_parts): + if i >= self.too_dirty_cutoff: + return [] + start = max(0, start - self.window_to_remove) + end = min(len(dirty_string), end + self.window_to_remove) + + if start - splice_idx > self.minimum_slice_length: + clean_chunks.append(dirty_string[splice_idx:start]) + splice_idx = end + + if end < len(dirty_string) - self.minimum_slice_length: + clean_chunks.append(dirty_string[end + 1 :]) + + return clean_chunks + + ############## + # Fast C++ + ############## + + def register_contaminant_cpp(self, dirt_string) -> None: + self.dirt_ngrams.update( + janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n) + ) + + def clean_cpp(self, dirty_string: str) -> List[str]: + contamination_indices = janitor_util.clean_ngram_with_indices( + dirty_string, self.delete_chars, self.ngram_n + ) + return self._split_chunks(dirty_string, contamination_indices) + + ############## + # Slow python + ############## + + def normalize_string(self, s: str) -> str: + return s.translate(self.translation_table) + + def register_contaminant_python(self, dirt_string: str) -> None: + self.dirt_ngrams.update( + word_ngrams(self.normalize_string(dirt_string), self.ngram_n) + ) + + def clean_python(self, dirty_string: str) -> List[str]: + contamination_indices = ( + (None, *idx_pair) + for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n) + if self.normalize_string(dirty_ngram) in self.dirt_ngrams + ) + return self._split_chunks(dirty_string, contamination_indices) + + +################################################################## +# Tests +################################################################# + +# def print_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 + +# for i in range(1, 10, 2): +# pprint(janitor_util.clean_ngram(source, string.punctuation, i)) +# for ngram, start, end in \ +# janitor_util.clean_ngram_with_indices(source, string.punctuation, i): +# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n")) + + +# def test_cpp(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. Clean he he" + +# jan_python = Janitor() +# jan_cpp = Janitor() + +# jan_python.register_contaminant_python(contaminant) +# jan_cpp.register_contaminant(contaminant) + +# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams) + +# assert jan_python.clean_python(source) == jan_cpp.clean(source), \ +# (jan_python.clean_python(source), jan_cpp.clean(source)) + +# print("Passed test, python==cpp") + + +# def benchmark(): +# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html +# setup = \ +# """ +# with open("data/enwik8", "r") as f: +# data = f.read() +# jan = Janitor(too_dirty_cutoff=1000) +# jan.register_contaminant(''' +# theories is that there is a connection between "geekdom" and autism. +# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled " +# The [[Geek]] Syndrome", which is a point argued by many in the autism rights +# movement{{ref|Wired}}. This article, many professionals assert, is just one example of +# the media's application of mental disease labels to what is actually variant normal behavior +# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual +# interests, even when they seem unusual to others, are not in themselves signs of autism or +# Asperger's syndrome. Others assert that it is actually the medical profession which is applying +# mental disease labels to children who in the past would have simply been accepted as a little +# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue. +# Due to the recent publicity surrounding autism and autis +# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first, +# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first +# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties +# would last, took a cautious approach, preferring to save the revenue rather than investing it in +# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential +# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his +# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]], +# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M, +# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995), +# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the +# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the +# [[United Arab Emirates]]. After the Emirates gained independence in 1971, +# ''') +# """ + +# n = 1 +# print(f"Timing {n} run on 100 MB") +# print("Register contaminant") +# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n)) + +# print("Clean") +# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n)) +# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n)) + + +# def test_janitor_general(): +# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2 +# contaminant = "dirty boy. Clean he he" + +# jan = Janitor(ngram_n=3) +# jan.register_contaminant(contaminant) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + +# filename = "data/saved_contam" +# jan.save_contamination_ngrams(filename) + +# jan = Janitor(ngram_n=3) +# jan.load_contamination_ngrams(filename) +# cleaned = " ".join(jan.clean(source)) +# for contam in jan.dirt_ngrams: +# assert contam not in cleaned, contam + + +# if __name__ == "__main__": +# test() +# # print_cpp() +# # test_cpp() +# # benchmark() diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/evaluator.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2530e45d2e57dcae926e127ea7e074862ae8f9 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/evaluator.py @@ -0,0 +1,765 @@ +import itertools +import json +import logging +import random +import time +from collections import defaultdict +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +import torch + +import dllm_eval.api.metrics +import dllm_eval.api.registry +import dllm_eval.api.task +import dllm_eval.models +from dllm_eval.caching.cache import delete_cache +from dllm_eval.evaluator_utils import ( + consolidate_group_results, + consolidate_results, + get_sample_size, + get_subtask_list, + get_task_list, + prepare_print_tasks, + print_writeout, + run_task_tests, +) +from dllm_eval.loggers import EvaluationTracker +from dllm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash +from dllm_eval.tasks import TaskManager, get_task_dict +from dllm_eval.utils import ( + handle_non_serializable, + hash_string, + positional_deprecated, + setup_logging, + simple_parse_args_string, +) + + +if TYPE_CHECKING: + from dllm_eval.api.model import LM + from dllm_eval.api.task import Task + +eval_logger = logging.getLogger(__name__) + + +@positional_deprecated +def simple_evaluate( + model, + model_args: Optional[Union[str, dict]] = None, + tasks: Optional[List[Union[str, dict, object]]] = None, + num_fewshot: Optional[int] = None, + batch_size: Optional[Union[int, str]] = None, + max_batch_size: Optional[int] = None, + device: Optional[str] = None, + use_cache: Optional[str] = None, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + delete_requests_cache: bool = False, + limit: Optional[Union[int, float]] = None, + samples: Optional[dict] = None, + bootstrap_iters: int = 100000, + check_integrity: bool = False, + write_out: bool = False, + log_samples: bool = True, + evaluation_tracker: Optional[EvaluationTracker] = None, + system_instruction: Optional[str] = None, + apply_chat_template: Union[bool, str] = False, + fewshot_as_multiturn: bool = False, + gen_kwargs: Union[str, dict, None] = None, + task_manager: Optional[TaskManager] = None, + verbosity=None, + predict_only: bool = False, + random_seed: int = 0, + numpy_random_seed: int = 1234, + torch_random_seed: int = 1234, + fewshot_random_seed: int = 1234, + confirm_run_unsafe_code: bool = False, + metadata: Optional[dict] = None, +): + """Instantiate and evaluate a model on a list of tasks. + + :param model: Union[str, LM] + Name of model or LM object, see dllm_eval.models.get_model + :param model_args: Optional[str, dict] + String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object. + Ignored if `model` argument is a LM object. + :param tasks: list[Union[str, dict, Task]] + List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise. + :param num_fewshot: int + Number of examples in few-shot context + :param batch_size: int or str, optional + Batch size for model + :param max_batch_size: int, optional + Maximal batch size to try with automatic batch size detection + :param device: str, optional + PyTorch device (e.g. "cpu" or "cuda:0") for running models + :param use_cache: str, optional + A path to a sqlite db file for caching model responses. `None` if not caching. + :param cache_requests: bool, optional + Speed up evaluation by caching the building of dataset requests. `None` if not caching. + :param rewrite_requests_cache: bool, optional + Rewrites all the request cache if set to `True`. `None` if not desired. + :param delete_requests_cache: bool, optional + Deletes all the request cache if set to `True`. `None` if not desired. + :param limit: int or float, optional + Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples. + :param samples: dictionary, optional + Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}. + :param bootstrap_iters: + Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed. + :param check_integrity: bool + Whether to run the relevant part of the test suite for the tasks + :param write_out: bool + If True, write out an example document and model input for checking task integrity + :param log_samples: bool + If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis + :param system_instruction: str + System instruction to be applied to the prompt + :param apply_chat_template: Union[bool, str] + Specifies whether to apply a chat template to the prompt. + - If set to True, the default chat template is applied. + - If set to a string, applies the specified chat template by name. + Defaults to False (no chat template applied). + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param gen_kwargs: dict or comma-separated string + Arguments for model generation + Ignored for all tasks with loglikelihood output_type + :param verbosity: str + Verbosity level for logging + :param predict_only: bool + If true only model outputs will be generated and returned. Metrics will not be evaluated + :param random_seed: int + Random seed for python's random module. If set to None, the seed will not be set. + :param numpy_random_seed: int + Random seed for numpy. If set to None, the seed will not be set. + :param torch_random_seed: int + Random seed for torch. If set to None, the seed will not be set. + :param fewshot_random_seed: int + Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None. + :param metadata: dict + Additional metadata to be added to the task manager. Will get passed to the download function of the task. + + return + Dictionary of results + """ + if verbosity is not None: + setup_logging(verbosity=verbosity) + start_date = time.time() + + if limit is not None and samples is not None: + raise ValueError( + "Either 'limit' or 'samples' must be None, but both are not None." + ) + + if ( + (isinstance(model_args, str) and "inst" in model_args.lower()) + or ( + isinstance(model_args, dict) + and any("inst" in str(v).lower() for v in model_args.values()) + ) + ) and not apply_chat_template: + eval_logger.warning( + "Model appears to be an instruct variant but chat template is not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)." + ) + + if delete_requests_cache: + eval_logger.info("Deleting requests cache...") + delete_cache() + + seed_message = [] + if random_seed is not None: + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412 + seed_message.append(f"Setting random seed to {random_seed}") + random.seed(random_seed) + + if numpy_random_seed is not None: + seed_message.append(f"Setting numpy seed to {numpy_random_seed}") + np.random.seed(numpy_random_seed) + + if torch_random_seed is not None: + seed_message.append(f"Setting torch manual seed to {torch_random_seed}") + torch.manual_seed(torch_random_seed) + + if fewshot_random_seed is not None: + seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}") + + if seed_message: + eval_logger.info(" | ".join(seed_message)) + + if tasks is None: + tasks = [] + if len(tasks) == 0: + raise ValueError( + "No tasks specified, or no tasks found. Please verify the task names." + ) + + if gen_kwargs is not None: + if isinstance(gen_kwargs, str): + gen_kwargs = simple_parse_args_string(gen_kwargs) + eval_logger.warning( + f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. " + "Ensure 'do_sample=True' for non-greedy decoding!" + ) + if not gen_kwargs: + gen_kwargs = None + + if isinstance(model, str): + if model_args is None: + eval_logger.warning("model_args not specified. Using defaults.") + model_args = "" + + if isinstance(model_args, dict): + eval_logger.info( + f"Initializing {model} model, with arguments: {model_args}" + ) + lm = dllm_eval.api.registry.get_model(model).create_from_arg_obj( + model_args, + { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + }, + ) + + else: + eval_logger.info( + f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}" + ) + lm = dllm_eval.api.registry.get_model(model).create_from_arg_string( + model_args, + { + "batch_size": batch_size, + "max_batch_size": max_batch_size, + "device": device, + }, + ) + else: + if not isinstance(model, dllm_eval.api.model.LM): + raise TypeError( + f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of dllm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `dllm_eval.models.huggingface.HFLM(pretrained=my_model)` first." + ) + eval_logger.info("Using pre-initialized model") + lm = model + + if use_cache is not None: + eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}") + lm = dllm_eval.api.model.CachingLM( + lm, + use_cache + # each rank receives a different cache db. + # necessary to avoid multiple writes to cache at once + + "_rank" + + str(lm.rank) + + ".db", + ) + + if task_manager is None: + metadata = ( + simple_parse_args_string(model_args) + if isinstance(model_args, str) + else model_args + if isinstance(model_args, dict) + else {} + ) | (metadata or {}) + task_manager = TaskManager(metadata=metadata) + + task_dict = get_task_dict( + tasks, + task_manager, + ) + + # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups. + # (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed) + def _adjust_config(task_dict): + adjusted_task_dict = {} + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + adjusted_task_dict = { + **adjusted_task_dict, + **{task_name: _adjust_config(task_obj)}, + } + + else: + if task_obj.get_config("output_type") == "generate_until": + if gen_kwargs is not None: + task_obj.set_config( + key="generation_kwargs", value=gen_kwargs, update=True + ) + eval_logger.info( + f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}" + ) + + if predict_only: + eval_logger.info( + f"Processing {task_name} in output-only mode. Metrics will not be calculated!" + ) + # we have to change the class properties post-hoc. This is pretty hacky. + task_obj.override_metric(metric_name="bypass") + + # override tasks' fewshot values to the provided num_fewshot arg value + # except if tasks have it set to 0 manually in their configs--then we should never overwrite that + if num_fewshot is not None: + if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0: + eval_logger.info( + f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored." + ) + else: + eval_logger.warning( + f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" + ) + task_obj.set_config(key="num_fewshot", value=num_fewshot) + else: + # if num_fewshot not provided, and the task does not define a default one, default to 0 + if ( + default_num_fewshot := task_obj.get_config("num_fewshot") + ) is None: + task_obj.set_config(key="num_fewshot", value=0) + # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) + task_obj.set_fewshot_seed(seed=fewshot_random_seed) + + adjusted_task_dict[task_name] = task_obj + + return adjusted_task_dict + + task_dict = _adjust_config(task_dict) + + if check_integrity: + run_task_tests(task_list=tasks) + + if evaluation_tracker is not None: + evaluation_tracker.general_config_tracker.log_experiment_args( + model_source=model, + model_args=model_args, + system_instruction=system_instruction, + chat_template=lm.chat_template(apply_chat_template) + if apply_chat_template + else None, + fewshot_as_multiturn=fewshot_as_multiturn, + ) + + results = evaluate( + lm=lm, + task_dict=task_dict, + limit=limit, + samples=samples, + cache_requests=cache_requests, + rewrite_requests_cache=rewrite_requests_cache, + bootstrap_iters=bootstrap_iters, + write_out=write_out, + log_samples=True if predict_only else log_samples, + system_instruction=system_instruction, + apply_chat_template=apply_chat_template, + fewshot_as_multiturn=fewshot_as_multiturn, + verbosity=verbosity, + confirm_run_unsafe_code=confirm_run_unsafe_code, + ) + if verbosity is not None: + setup_logging(verbosity=verbosity) + + if lm.rank == 0: + if isinstance(model, str): + model_name = model + elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"): + model_name = model.config._name_or_path + else: + model_name = type(model).__name__ + + # add info about the model and few shot config + results["config"] = { + "model": model_name, + "model_args": model_args, + } + # add more detailed model info if available + if isinstance(lm, dllm_eval.models.huggingface.HFLM): + results["config"].update(lm.get_model_info()) + # add info about execution + results["config"].update( + { + "batch_size": batch_size, + "batch_sizes": ( + list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else [] + ), + "device": device, + "use_cache": use_cache, + "limit": limit, + "bootstrap_iters": bootstrap_iters, + "gen_kwargs": gen_kwargs, + "random_seed": random_seed, + "numpy_seed": numpy_random_seed, + "torch_seed": torch_random_seed, + "fewshot_seed": fewshot_random_seed, + } + ) + results["git_hash"] = get_git_commit_hash() + results["date"] = start_date + add_env_info(results) # additional environment info to results + add_tokenizer_info(results, lm) # additional info about tokenizer + return results + else: + return None + + +@positional_deprecated +def evaluate( + lm: "LM", + task_dict, + limit: Optional[int] = None, + samples: Optional[dict] = None, + cache_requests: bool = False, + rewrite_requests_cache: bool = False, + bootstrap_iters: Optional[int] = 100000, + write_out: bool = False, + log_samples: bool = True, + system_instruction: Optional[str] = None, + apply_chat_template: Union[bool, str] = False, + fewshot_as_multiturn: bool = False, + verbosity: str = "INFO", + confirm_run_unsafe_code: bool = False, +): + """Instantiate and evaluate a model on a list of tasks. + + :param lm: obj + Language Model + :param task_dict: dict[str, Task] + Dictionary of tasks. Tasks will be taken to have name type(task).config.task . + :param limit: int, optional + Limit the number of examples per task (only use this for testing) + :param samples: dictionary, optional + Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}. + :param cache_requests: bool, optional + Speed up evaluation by caching the building of dataset requests. + :param rewrite_requests_cache: bool, optional + Rewrites all the request cache if set to `True`. + :param bootstrap_iters: + Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations. + :param write_out: bool + If True, write out an example document and model input for checking task integrity + :param log_samples: bool + If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis + :param system_instruction: str + System instruction to be applied to the prompt + :param apply_chat_template: Union[bool, str] + Specifies whether to apply a chat template to the prompt. + - If set to True, the default chat template is applied. + - If set to a string, applies the specified chat template by name. + Defaults to False (no chat template applied). + :param fewshot_as_multiturn: bool + Whether to provide the fewshot examples as a multiturn conversation or a single user turn. + :param verbosity: str + Verbosity level for logging + :param confirm_run_unsafe_code: bool + Whether to confirm running tasks marked as unsafe. + :return + Dictionary of results + """ + + if limit is not None and samples is not None: + raise ValueError( + "Either 'limit' or 'samples' must be None, but both are not None." + ) + if samples is not None: + eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}") + if apply_chat_template: + eval_logger.warning( + "Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details." + ) + # tracks all Instances/requests a model must generate output on. + requests = defaultdict(list) + # stores the amount to pad out reqs per req. type so that + # number of fwd passes per distributed rank is equal + padding_requests = defaultdict(int) + + # get lists of group hierarchy and each type of request + eval_tasks = get_task_list(task_dict) + if not log_samples: + if not all( + "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys() + for task_output in eval_tasks + ): + raise ValueError("log_samples must be True for 'bypass' metric-only tasks") + + # validation checks: + # 1.are we running multimodal task <-> non-multimodal model class, or vice-versa. + # 2.are we running code that is marked as unsafe. + incompatible_tasks = [] + for task_output in eval_tasks: + task: Task = task_output.task + + if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False): + incompatible_tasks.append(task_output.task_name) + elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code: + raise ValueError( + f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task." + ) + if len(incompatible_tasks) > 0: + if not getattr(lm, "MULTIMODAL", False): + raise ValueError( + f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type." + ) + # end validation check + + # Cache the limit arg. + limit_arg = limit + limits = [] + for task_output in eval_tasks: + task: Task = task_output.task + + limit = get_sample_size(task, limit_arg) + limits.append(limit) + task.build_all_requests( + limit=limit, + samples=samples.get(task_output.task_name, None) + if samples is not None + else samples, + rank=lm.rank, + world_size=lm.world_size, + cache_requests=cache_requests, + rewrite_requests_cache=rewrite_requests_cache, + system_instruction=system_instruction, + apply_chat_template=bool(apply_chat_template), + fewshot_as_multiturn=fewshot_as_multiturn, + chat_template=getattr(lm, "apply_chat_template") + if apply_chat_template + else None, + tokenizer_name=getattr(lm, "tokenizer_name", "") + if apply_chat_template + else "", + ) + eval_logger.debug( + f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" + ) + if write_out: + print_writeout(task) + # aggregate Instances by LM method requested to get output. + for instance in task.instances: + reqtype = instance.request_type + requests[reqtype].append(instance) + + if lm.world_size > 1: + instances_rnk = torch.tensor(len(task._instances), device=lm.device) + gathered_item = ( + lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() + ) + # "multiple_choice" task types dispatch (several) "loglikelihood" request types + reqtype = ( + "loglikelihood" + if task.OUTPUT_TYPE == "multiple_choice" + else task.OUTPUT_TYPE + ) + # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks) + numpad = max(gathered_item) - gathered_item[lm.rank] + # todo: may not account for padding in cases like SquadV2 which has multiple req types + padding_requests[reqtype] += numpad + + ### Run LM on inputs, get all outputs ### + # execute each type of request + for reqtype, reqs in requests.items(): + eval_logger.info(f"Running {reqtype} requests") + # create `K` copies of each request `req` based off `K = req.repeats` + cloned_reqs = [] + for req in reqs: + cloned_reqs.extend([req] * req.repeats) + + if (lm.world_size > 1) and (padding_requests[reqtype] > 0): + for _ in range(padding_requests[reqtype]): + cloned_reqs.extend([req] * req.repeats) + + # run requests through model + resps = getattr(lm, reqtype)(cloned_reqs) + + # put responses from model into a list of length K for each request. + for x, req in zip(resps, cloned_reqs): + req.resps.append(x) + + if lm.world_size > 1: + lm.accelerator.wait_for_everyone() + + RANK = lm.rank + WORLD_SIZE = lm.world_size + ### Postprocess outputs ### + # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately) + for task_output, limit in zip(eval_tasks, limits): + task = task_output.task + task.apply_filters() + + ### Collect values of metrics on all datapoints ### + # # unpack results and sort back in order and return control to Task + # TODO: make it possible to use a different metric per filter + # Pre-process task.instances to group by doc_id + instances_by_doc_id = defaultdict(list) + for instance in task.instances: + instances_by_doc_id[instance.doc_id].append(instance) + # Sort instances within each group + for instances in instances_by_doc_id.values(): + instances.sort(key=lambda x: x.idx) + # iterate over different filters used + for filter_key in task.instances[0].filtered_resps.keys(): + indices = ( + samples.get(task_output.task_name, None) + if samples is not None + else None + ) + doc_iterator = task.doc_iterator( + rank=RANK, + limit=limit, + world_size=WORLD_SIZE, + samples=indices, + ) + for doc_id, doc in doc_iterator: + if indices: + doc_id_true = indices[doc_id] + else: + doc_id_true = doc_id + requests = instances_by_doc_id[doc_id] + metrics = task.process_results( + doc, [req.filtered_resps[filter_key] for req in requests] + ) + if log_samples: + target = task.doc_to_target(doc) + example = { + "doc_id": doc_id_true, + "doc": doc, + "target": target, + "arguments": [req.args for req in requests], + "resps": [req.resps for req in requests], + "filtered_resps": [ + req.filtered_resps[filter_key] for req in requests + ], + "filter": filter_key, + "metrics": list(metrics.keys()), + "doc_hash": hash_string( + json.dumps( + requests[0].doc, + indent=2, + default=handle_non_serializable, + ensure_ascii=False, + ) + ), + "prompt_hash": hash_string(requests[0].arguments[0]), + "target_hash": hash_string(str(target)), + } + example.update(metrics) + task_output.logged_samples.append(example) + for metric, value in metrics.items(): + task_output.sample_metrics[(metric, filter_key)].append(value) + + if WORLD_SIZE > 1: + # if multigpu, then gather data across all ranks to rank 0 + # first gather logged samples across all ranks + for task_output in eval_tasks: + if log_samples: + # for task_name, task_samples in list(samples.items()): + full_samples = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.logged_samples, + object_gather_list=full_samples, + dst=0, + ) + + if RANK == 0: + task_output.logged_samples = list( + itertools.chain.from_iterable(full_samples) + ) + + # then collect metrics across all ranks + for metrics in task_output.sample_metrics: + metric_list = [None] * WORLD_SIZE if RANK == 0 else None + torch.distributed.gather_object( + obj=task_output.sample_metrics[metrics], + object_gather_list=metric_list, + dst=0, + ) + if RANK == 0: + task_output.sample_metrics[metrics] = list( + itertools.chain.from_iterable(metric_list) + ) + + if RANK == 0: + ### Aggregate results over all datapoints ### + # aggregate results ; run bootstrap CIs + for task_output in eval_tasks: + task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters) + ( + results, + samples, + configs, + versions, + num_fewshot, + higher_is_better, + ) = consolidate_results(eval_tasks) + + ### Calculate group metrics ### + if bool(results): + results, versions, show_group_table, *_ = consolidate_group_results( + results, versions, task_dict + ) + + results_agg, group_agg = prepare_print_tasks(task_dict, results) + subtask_list = get_subtask_list(task_dict) + + # collect all higher_is_better values for metrics + # in the group's subtasks. + # TODO: clean this up ; unify with the below metric_list loop? + _higher_is_better = {} + for group, task_list in subtask_list.items(): + if ( + len(task_list) != 0 + ): # subtask list will list "task_name": [] for solo tasks + for task in task_list: + for m, h in higher_is_better[task].items(): + if m not in _higher_is_better.keys(): + _higher_is_better[m] = h + + if ( + m in _higher_is_better + and _higher_is_better[m] is not None + and _higher_is_better[m] != h + ): + eval_logger.warning( + f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None." + ) + _higher_is_better[m] = None + higher_is_better[group] = _higher_is_better + + results_dict = { + "results": dict(results_agg.items()), + **( + {"groups": dict(group_agg.items())} + if (bool(group_agg) & show_group_table) + else {} + ), + "group_subtasks": dict(reversed(subtask_list.items())), + "configs": dict(sorted(configs.items())), + "versions": dict(sorted(versions.items())), + "n-shot": dict(sorted(num_fewshot.items())), + "higher_is_better": dict(sorted(higher_is_better.items())), + "n-samples": { + task_output.task_name: { + "original": len(task_output.task.eval_docs), + "effective": min( + limit if limit else len(task_output.task.eval_docs), + len(task_output.task.eval_docs), + ), + } + for task_output, limit in zip(eval_tasks, limits) + }, + } + if log_samples: + results_dict["samples"] = dict(samples) + + return results_dict + + else: + return None + + +def request_caching_arg_to_dict(cache_requests: str) -> dict: + request_caching_args = { + "cache_requests": cache_requests in {"true", "refresh"}, + "rewrite_requests_cache": cache_requests == "refresh", + "delete_requests_cache": cache_requests == "delete", + } + + return request_caching_args diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/evaluator_utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/evaluator_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a17950fe606559fb1c3c72fb3e8404759788bbe --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/evaluator_utils.py @@ -0,0 +1,554 @@ +import collections +import logging +import math +import pathlib +import sys +from typing import List, Optional, Tuple, Union + +from dllm_eval.api.group import ConfigurableGroup +from dllm_eval.api.metrics import ( + aggregate_subtask_metrics, + mean, + pooled_sample_stderr, + stderr_for_metric, +) +from dllm_eval.api.task import Task +from dllm_eval.utils import positional_deprecated + + +eval_logger = logging.getLogger(__name__) + + +class TaskOutput: + """ + Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task. + + Attributes: + task (object): The task object. + task_name (str): The name of the task. + task_config (dict): The configuration of the task. + version (str): The version of the task. + group_name (str): The name of the task group. + n_shot (int): The number of shots for the task. + task_alias (str): The alias of the task. + group_alias (str): The alias of the task group. + is_group (bool): Indicates if the task is a group. + logged_samples (list): The list of logged samples. + sample_len (int): The length of the samples. + sample_metrics (defaultdict): The dictionary of samples' metrics. + agg_metrics (defaultdict): The dictionary of aggregate metrics. + + Methods: + from_taskdict(cls, task_name: str, task): + Creates a TaskOutput instance from a task dictionary. + + calculate_aggregate_metric(bootstrap_iters=100000) -> None: + Calculates the aggregate metrics for the task. + """ + + def __init__( + self, + task=None, + task_name=None, + task_config=None, + version=None, + group_name=None, + n_shot=None, + task_alias=None, + group_alias=None, + is_group=None, + ): + self.task = task + self.task_config = task_config + self.task_name = task_name + self.group_name = group_name + self.version = version + self.n_shot = n_shot + self.task_alias = task_alias + self.group_alias = group_alias + self.is_group = is_group + self.logged_samples = [] + self.sample_len = None + self.sample_metrics = collections.defaultdict(list) + self.agg_metrics = collections.defaultdict(list) + + @classmethod + def from_taskdict(cls, task_name: str, task): + if isinstance(task, tuple): + group_name, task = task + else: + group_name = None + if not task: + # these gets filtered out in get_task_list + # once they are added to group hierarchy + is_group = True + return cls( + task=task, task_name=task_name, is_group=is_group, group_name=group_name + ) + version = task.VERSION + task_config = dict(task.dump_config()) + if (n_shot := task_config.get("num_fewshot")) == 0: + n_shot = task_config.get("metadata", {}).get("num_fewshot", 0) + task_alias = task_config.get("alias") + group_alias = task_config.get("group_alias") + return cls( + task=task, + task_name=task_name, + task_config=task_config, + group_name=group_name, + version=version, + n_shot=n_shot, + task_alias=task_alias, + group_alias=group_alias, + ) + + def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None: + for (metric, filter_key), items in self.sample_metrics.items(): + try: + agg_fn = self.task.aggregation()[metric] + except KeyError: + # This is when process results output an arbitrary metric + # TODO: Handle this better and allow other aggregate functions other than mean. + agg_fn = mean + metric_key = f"{metric},{filter_key}" + self.agg_metrics[metric_key] = agg_fn(items) + self.sample_len = len(items) # TODO: same sample size for each metric? + if isinstance(bootstrap_iters, int): + stderr_fn = stderr_for_metric( + metric=agg_fn, + bootstrap_iters=min(bootstrap_iters, 100) + if metric in ["bleu", "chrf", "ter"] + else bootstrap_iters, + ) + self.agg_metrics[f"{metric}_stderr,{filter_key}"] = ( + stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A" + ) + else: + raise ValueError( + f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations." + ) + + def __repr__(self): + return ( + f"TaskOutput(task_name={self.task_name}, " + f"group_name={self.group_name}, " + f"version={self.version}, " + f"n_shot={self.n_shot}, " + f"task_alias={self.task_alias}, " + f"group_alias={self.group_alias})" + ) + + +def get_task_list(task_dict: dict) -> List[TaskOutput]: + outputs = [] + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + _outputs = get_task_list(task_obj) + outputs.extend(_outputs) + else: + task_output = TaskOutput.from_taskdict(task_name, task_obj) + outputs.append(task_output) + + return outputs + + +def get_subtask_list(task_dict, task_root=None, depth=0): + subtask_list = {} + for group_obj, task_obj in task_dict.items(): + if isinstance(group_obj, ConfigurableGroup): + # group_name = group_obj.group_name + group_name = group_obj.group_name + else: + group_name = group_obj + if isinstance(task_obj, dict): + _subtask_list = get_subtask_list( + task_obj, task_root=group_name, depth=depth + 1 + ) + if task_root: + subtask_list.setdefault((task_root, depth), []).extend( + [ + _task + for (_task, _depth) in _subtask_list.keys() + if (_depth - 1) == depth + ] + ) + + subtask_list = {**subtask_list, **_subtask_list} + else: + if isinstance(task_obj, ConfigurableGroup): + # group_or_task_name = task_obj.group_name + group_or_task_name = task_obj.group_name + elif isinstance(task_obj, Task): + # group_or_task_name = task_obj.task_name + group_or_task_name = task_obj.task_name + + if task_root is None: + subtask_list.setdefault((group_or_task_name, depth), []) + else: + subtask_list.setdefault((task_root, depth), []).append( + group_or_task_name + ) + + if depth == 0: + _subtask_list = {} + for group_key, task_list in subtask_list.items(): + group_name, depth = group_key + _subtask_list[group_name] = task_list + subtask_list = _subtask_list + + return subtask_list + + +def print_writeout(task) -> None: + for inst in task.instances: + # print the prompt for the first few documents + if inst.doc_id < 1: + eval_logger.info( + f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\ + \n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)" + ) + eval_logger.info(f"Request: {str(inst)}") + + +def get_sample_size(task, limit: Optional[int]) -> Union[int, None]: + if limit is not None: + limit = ( + int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit) + ) + return limit + + +def prepare_print_tasks( + task_dict: dict, + results: dict, + task_depth=0, + group_depth=0, +) -> Tuple[dict, dict]: + """ + @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its + value is a list of task names. + @param results: Dictionary containing the results of each task. Each key is a + group name and its value is a dictionary of task results. + @param task_depth: The indentation level for printing the task + hierarchy. Default is 0. + @param group_depth: The indentation level for printing the group + hierarchy. Default is 0. + @return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains + aggregated results for each task, and groups_agg contains aggregated results for each group. + + Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. + """ + + def _sort_task_dict(task_dict): + """ + Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name. + Required so that we end up sorting within each sub-header correctly. + """ + + return dict( + sorted( + task_dict.items(), + key=lambda item: item[0].group_name + if isinstance(item[0], ConfigurableGroup) + else item[0], + ) + ) + + task_agg = collections.defaultdict(dict) + group_agg = collections.defaultdict(dict) + task_dict = _sort_task_dict(task_dict) + for task_or_group_name, task_or_group_obj in task_dict.items(): + tab_string = " " * task_depth + "- " if task_depth > 0 else "" + if isinstance(task_or_group_name, ConfigurableGroup): + # string_name = task_or_group_name.group_name + name = task_or_group_name.group_name + from_configurable_group = True + task_or_group_obj = _sort_task_dict(task_or_group_obj) + elif isinstance(task_or_group_name, str): + name = task_or_group_name + if isinstance(task_or_group_obj, Task): + # string_name = task_or_group_obj.task_name + name = task_or_group_obj.task_name + from_configurable_group = False + + task_agg[name] = results[name].copy() + if from_configurable_group: + if task_or_group_name.group_alias is not None: + alias = task_or_group_name.group_alias + else: + alias = task_or_group_name.group + else: + if "alias" in task_agg[name]: + alias = task_agg[name]["alias"] + else: + alias = name + + task_agg[name]["alias"] = tab_string + alias + if "samples" in task_agg[name]: + task_agg[name].pop("samples") + + if from_configurable_group and (" " not in results[name]): + group_tab_string = " " * group_depth + "- " if group_depth > 0 else "" + group_agg[name] = results[name].copy() + group_agg[name]["alias"] = group_tab_string + alias + if "samples" in group_agg[name]: + group_agg[name].pop("samples") + + if isinstance(task_or_group_obj, dict): + task_depth += 1 + group_depth += 1 + _task_agg, _group_agg = prepare_print_tasks( + task_or_group_obj, results, task_depth, group_depth + ) + task_agg = { + **task_agg, + **_task_agg, + } + group_agg = {**group_agg, **_group_agg} + task_depth -= 1 + group_depth -= 1 + return task_agg, group_agg + + +def consolidate_results( + eval_tasks: List[TaskOutput], +) -> Tuple[dict, dict, dict, dict, dict, dict]: + """ + @param eval_tasks: list(TaskOutput). + @return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot. + + Consolidates the results of multiple evaluation tasks into a single structure. + + The method iterates over each evaluation instance and extracts relevant information to create the consolidated + results structure. The consolidated results structure has the following properties: + + - results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains + metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task + aliases specified in the task configuration. + - samples: A defaultdict with task names as keys and lists of log samples as values. + - configs: A defaultdict with task names as keys and task configurations as values. + - versions: A defaultdict with task names as keys and task versions as values. + - num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values. + - higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better + for each metric as values. + + The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple. + """ + # stores the final result for each task, for each metric/filter pair. + results = collections.defaultdict(dict) + # logs info about each document evaluated. + samples = collections.defaultdict(list) + # store num-fewshot value per task + num_fewshot = collections.defaultdict(int) + # Tracks the YAML configs of all chosen task + configs = collections.defaultdict(dict) + # Tracks each task's version. + versions = collections.defaultdict(dict) + # Track `higher_is_better` for each metric + higher_is_better = collections.defaultdict(dict) + + for task_output in eval_tasks: + if "task_alias" in (task_config := task_output.task_config): + results[task_output.task_name]["alias"] = task_config["task_alias"] + else: + results[task_output.task_name]["alias"] = task_output.task_name + if group_alias := task_output.group_alias: + if group_alias not in results and (group_name := task_output.group_name): + results[group_name]["alias"] = group_alias + num_fewshot[task_output.task_name] = task_output.n_shot + configs[task_output.task_name] = task_output.task_config + versions[task_output.task_name] = task_output.version + samples[task_output.task_name] = task_output.logged_samples + higher_is_better[task_output.task_name] = task_output.task.higher_is_better() + for (metric, filter_key), items in task_output.sample_metrics.items(): + metric_key = f"{metric},{filter_key}" + results[task_output.task_name][metric_key] = task_output.agg_metrics[ + metric_key + ] + results[task_output.task_name]["samples"] = task_output.sample_len + results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = ( + task_output.agg_metrics[f"{metric}_stderr,{filter_key}"] + ) + return results, samples, configs, versions, num_fewshot, higher_is_better + + +def consolidate_group_results( + results, + versions, + task_dict, + task_root=None, + show_group_table=False, + task_aggregation_list=None, +) -> Tuple[dict, dict, bool, Union[None,]]: + """ + (Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info. + + @return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below: + + - results: A defaultdict with task names (and, after this function is called, group names of + groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys. + - versions: A defaultdict with task names (and, after this function is called, group names of + groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None). + - show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table. + - task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric. + + The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple. + In the top-level invocation of this function, task_aggregation_list is ignored. + """ + if task_root is None: + task_root = {} + + if task_aggregation_list is None: + task_aggregation_list = {} + + for group_or_task, group_or_task_info in task_dict.items(): + # Convert to string + if isinstance(group_or_task, ConfigurableGroup): + group_config = group_or_task.config + group_or_task = group_or_task.group_name + else: + group_config = None + + if isinstance(group_or_task_info, Task): + if task_root: + task_aggregation_list.setdefault(task_root, []).append( + group_or_task_info.task_name + ) + else: + ( + results, + versions, + show_group_table, + _task_aggregation_list, + ) = consolidate_group_results( + results, + versions, + group_or_task_info, + group_or_task, + show_group_table, + task_aggregation_list, + ) + if task_root: + task_aggregation_list.setdefault(task_root, []).extend( + task_aggregation_list.get(group_or_task, []) + ) + + if (group_config is None) or ( + group_config["aggregate_metric_list"] is None + ): + results[group_or_task][" "] = " " + continue + + if "aggregate_metric_list" in group_config: + agg_metric_list = group_config["aggregate_metric_list"] + + show_group_table = show_group_table | bool( + group_config["aggregate_metric_list"] + ) + + task_list = _task_aggregation_list[group_or_task] + + metric_list = list( + { + key + for task in task_list + for key in results[task].keys() + if "_stderr" not in key and key not in ["task", "alias", "samples"] + } + ) + for metric in metric_list: + stderr = "_stderr,".join(metric.split(",")) + + # gather metrics, sizes, and stderrs from subtasks + metrics = [ + results[task][metric] + for task in task_list + if metric in results[task] + ] # TODO: copy? + stderrs = [ + results[task][stderr] + for task in task_list + if stderr in results[task] + ] + sizes = [ + results[task]["samples"] + for task in task_list + if metric in results[task] + ] + + for metric_config in agg_metric_list: + for filter_name in metric_config["filter_list"]: + if metric != ",".join([metric_config["metric"], filter_name]): + continue + + # compute group's pooled metric and stderr + if metric_config["aggregation"] == "mean": + aggregate_fn = aggregate_subtask_metrics + elif callable(metric_config["aggregation"]): + aggregate_fn = metric_config["aggregation"] + else: + raise ValueError( + f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'" + ) + + results[group_or_task][metric] = aggregate_fn( + metrics, + sizes, + metric_config["weight_by_size"], + ) + # TODO: calculate groups' metrics using arbitrary agg fns + if "N/A" in stderrs: + results[group_or_task][stderr] = "N/A" + else: + # NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere + results[group_or_task][stderr] = pooled_sample_stderr( + stderrs, sizes + ) + + results[group_or_task]["samples"] = sum(sizes) + group_metadata = group_config.get("metadata", None) + if group_metadata is not None: + versions[group_or_task] = group_metadata.get("version", None) + # print(results) + return results, versions, show_group_table, task_aggregation_list + + +@positional_deprecated +def find_test_root(start_path: pathlib.Path) -> pathlib.Path: + """ + Search upward in the directory tree to a maximum of three layers + to find and return the package root (containing the 'tests' folder) + """ + cur_path = start_path.resolve() + max_layers = 3 + for _ in range(max_layers): + if (cur_path / "tests" / "test_version_stable.py").exists(): + return cur_path + else: + cur_path = cur_path.parent.resolve() + raise FileNotFoundError( + f"Unable to find package root within {max_layers} upwards" + f"of {start_path}" + ) + + +@positional_deprecated +def run_task_tests(task_list: List[str]): + """ + Find the package root and run the tests for the given tasks + """ + import pytest + + package_root = find_test_root(start_path=pathlib.Path(__file__)) + task_string = " or ".join(task_list) + args = [ + f"{package_root}/tests/test_version_stable.py", + f"--rootdir={package_root}", + "-k", + f"{task_string}", + ] + sys.path.append(str(package_root)) + pytest_return_val = pytest.main(args) + if pytest_return_val: + raise ValueError( + f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}" + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8911d26c34cc07d1c92d20b904f48ef6fcce8ea4 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/__init__.py @@ -0,0 +1,25 @@ +from functools import partial +from typing import List + +from dllm_eval.api.filter import FilterEnsemble +from dllm_eval.api.registry import get_filter + +from . import custom, extraction, selection, transformation + + +def build_filter_ensemble( + filter_name: str, components: List[List[str]] +) -> FilterEnsemble: + """ + Create a filtering pipeline. + """ + filters = [] + for function, kwargs in components: + if kwargs is None: + kwargs = {} + # create a filter given its name in the registry + f = partial(get_filter(function), **kwargs) + # add the filter as a pipeline step + filters.append(f) + + return FilterEnsemble(name=filter_name, filters=filters) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/custom.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..07576f8a503f816de42ca1a80729edb517d75a5c --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/custom.py @@ -0,0 +1,17 @@ +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("custom") +class CustomFilter(Filter): + """ + Custom filter that applies a custom, user-defined function to the model responses. + """ + + def __init__(self, **kwargs) -> None: + self.filter_fn = kwargs.pop("filter_fn") + + super().__init__(**kwargs) + + def apply(self, resps, docs): + return self.filter_fn(resps, docs) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/decontamination.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/decontamination.py new file mode 100644 index 0000000000000000000000000000000000000000..ff4ff15a2d856a0a191aaeb5288c3706275dddd8 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/decontamination.py @@ -0,0 +1,25 @@ +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("decontaminate") +class DecontaminationFilter(Filter): + """ + A filter which evaluates + """ + + name = "track_decontamination" + + def __init__(self, path) -> None: + """ + + TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). + should further cache result on a given (task_name, doc_id) + """ + self._decontam_results = None + + def apply(self, resps, docs) -> None: + """ + Return {"no_contamination", "only_contamination"} keys for the 2 different subsets + """ + pass diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/extraction.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..3998e7c463e5f75cff6ed19c135441cc40ba3c8b --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/extraction.py @@ -0,0 +1,233 @@ +import re +import sys +import unicodedata + +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("regex") +class RegexFilter(Filter): + """A filter that extracts values from text using regex pattern matching. + + This filter applies a regex pattern to each model response and extracts matched values. + If no match is found, returns a fallback value. Useful for extracting structured data + (like numbers) from unstructured model outputs. + """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select: int = 0, + fallback: str = "[invalid]", + ) -> None: + """ + pass a string `regex` to run `re.compile(r"regex")` on. + `fallback` defines the output returned if no matches for the regex are located. + """ + self.regex_pattern = regex_pattern + self.regex = re.compile(regex_pattern) + self.group_select = group_select + self.fallback = fallback + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + def filter_set(inst): + filtered = [] + for resp in inst: + match = self.regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m] + if match: + match = match[0] + else: + match = self.fallback + match = match.strip() + else: + match = self.fallback + filtered.append(match) + return filtered + + filtered_resps = list(map(lambda x: filter_set(x), resps)) + return filtered_resps + + +@register_filter("regex_pos") +class POSFilter(Filter): + """ """ + + def __init__( + self, + regex_pattern: str = r"\['(.*?)'\]", + group_select=0, + fallback=None, + ) -> None: + """ + pass a string `regex` to run `re.compile(r"regex")` on. + `fallback` defines the output returned if no matches for the regex are located. + """ + if fallback is None: + fallback = ["invalid"] + self.regex_pattern = regex_pattern + self.regex = re.compile(regex_pattern) + self.group_select = group_select + self.fallback = fallback + + def apply(self, resps, docs): + def extract_tagged_tokens(text): + # Extract tagged tokens list from text input using regex + tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text) + return [(token, pos) for token, pos in tokens] + + def extract_pos_tags(result): + pos_tags = [] + if isinstance(result, str): + result = extract_tagged_tokens(result) + pos_tags.extend(pos for _, pos in result) + return pos_tags if pos_tags else self.fallback + + def filter_set(inst): + filtered = [] + for resp in inst: + match = extract_pos_tags(resp) + filtered.append(match) + return filtered + + filtered_resps = map(lambda x: filter_set(x), resps) + + return filtered_resps + + +@register_filter("remove_whitespace") +class WhitespaceFilter(Filter): + """Filters out leading whitespace from responses.""" + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + def filter_set(inst): + filtered_resp = [] + for resp in inst: + resp = resp.lstrip() + filtered_resp.append(resp) + return filtered_resp + + filtered_resps = [filter_set(resp) for resp in resps] + + return filtered_resps + + +@register_filter("multi_choice_regex") +class MultiChoiceRegexFilter(RegexFilter): + """ + A filter used to extract a model's answer on multiple choice questions with + letter answers. assumes each document has a "choices" field + containing the list of answer choices and that the answer label symbols + are of the form (A), (B), (C), ... or A, B, C. + """ + + def __init__( + self, + regex_pattern: str = r"#### (\-?[0-9\.\,]+)", + group_select=0, + fallback: str = "[invalid]", + ignore_case=False, + ignore_punctuation=False, + regexes_to_ignore=None, + ) -> None: + """ + regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure + - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. + - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. + group_select: Selects the (group_select)th match from the findall result. + ignore_case: Ignores the case during step 1 matching + ignore_punctuation: Remove the punctuation during step 1 matching + regexes_to_ignore: Remove these regexes during step 1 matching + """ + super().__init__(regex_pattern, group_select, fallback) + self.ignore_case = ignore_case + self.ignore_punctuation = ignore_punctuation + self.regexes_to_ignore = regexes_to_ignore + + def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + # here, we assume we have a list, in which each element is + # a list of model responses for some particular input/target pair. + # so we process each of these (same input/target response sets) + # independently (and keep them a list.) + + def find_match(regex, resp, convert_dict={}): + match = regex.findall(resp) + if match: + match = match[self.group_select] + if isinstance(match, tuple): + match = [m for m in match if m][0] + match = match.strip() + if match and match in convert_dict: + match = convert_dict[match] + return match + + punct_tbl = dict.fromkeys( + i + for i in range(sys.maxunicode) + if unicodedata.category(chr(i)).startswith("P") + ) + + def filter_ignores(st): + if self.regexes_to_ignore is not None: + for s in self.regexes_to_ignore: + st = re.sub(s, "", st) + + if self.ignore_case: + st = st.lower() + + if self.ignore_punctuation: + # https://stackoverflow.com/a/266162 + st = st.translate(punct_tbl) + return st + + filtered_resps = [] + + for r, doc in zip(resps, docs): + fallback_regexes = [] + choice_to_alpha = {} + next_alpha = "A" + + without_paren_fallback_regexes = [] + without_paren_to_target = {} + + choices = doc["choices"] + for c in choices: + m = filter_ignores(c.strip()) + fallback_regexes.append(f"{re.escape(m)}") + choice_to_alpha[m] = f"({next_alpha})" + + without_paren_fallback_regexes.append(next_alpha) + without_paren_to_target[next_alpha] = f"({next_alpha})" + + next_alpha = chr(ord(next_alpha) + 1) + fallback_regex = re.compile("|".join(fallback_regexes)) + without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) + without_paren_fallback_regex = re.compile( + rf":[\s]*({without_paren_fallback_regex})" + ) + + filtered = [] + for resp in r: + match = find_match(self.regex, resp) + if not match: + match = find_match( + fallback_regex, filter_ignores(resp), choice_to_alpha + ) + if not match: + match = find_match( + without_paren_fallback_regex, resp, without_paren_to_target + ) + if not match: + match = self.fallback + filtered.append(match) + filtered_resps.append(filtered) + + return filtered_resps diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/selection.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/selection.py new file mode 100644 index 0000000000000000000000000000000000000000..47b9c9bc71f254c91ba92aa8578b8c9f8cb3341f --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/selection.py @@ -0,0 +1,61 @@ +from collections import Counter + +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function +# that takes an input and returns a scalar and then should select the max reward, +# or should implement different filters for different ways of handling a reward model's inference. + + +@register_filter("take_first") +class TakeFirstFilter(Filter): + def __init__(self) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + def apply(self, resps, docs): + """ + Assuming each entry of `resps` is a list of model responses, we discard all but the first response. + """ + return map(lambda r: r[0], resps) + + +@register_filter("take_first_k") +class TakeKFilter(Filter): + def __init__(self, **kwargs) -> None: + self.k = kwargs.pop("k") + + super().__init__(**kwargs) + + def apply(self, resps, docs): + # need resp to be subscriptable to check below + resps = list(resps) + # check we have at least k responses per doc, else we can't take the first k + assert len(resps[0]) >= self.k, ( + f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ." + ) + return map(lambda r: r[: self.k], resps) + + +@register_filter("majority_vote") +class MajorityVoteFilter(Filter): + def __init__(self) -> None: + """ + Can define custom behavior here, if an individual instantiation of a Filter class should have state. + """ + + def apply(self, resps, docs): + """ + Each entry of `resps` is a list of model responses. + We select the response that occurs most frequently in each entry of `resps`. + """ + + def select_majority(resp): + counts = Counter(resp) + vote = counts.most_common(1)[0][0] + return vote + + return map(lambda r: [select_majority(r)], resps) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/transformation.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..48d2a21d7d510991977ebcf6601c2e7437ecb4bb --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/filters/transformation.py @@ -0,0 +1,122 @@ +import re + +from dllm_eval.api.filter import Filter +from dllm_eval.api.registry import register_filter + + +@register_filter("lowercase") +class LowercaseFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def filter_set(inst): + return [resp.lower() for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("uppercase") +class UppercaseFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def filter_set(inst): + return [resp.upper() for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("map") +class MapFilter(Filter): + def __init__(self, mapping_dict: dict = None, default_value=None) -> None: + """ + Initializes the MapFilter with a given mapping dictionary and default value. + + Args: + - mapping_dict (dict): A dictionary containing the key-value mappings. + Default is an empty dictionary. + - default_value (Any): The value to be returned when a key is not found in the mapping_dict. + Default is None. + + Example: + mapper = MapFilter({'A': 1, 'B': 2}, default_value=0) + """ + if mapping_dict is None: + mapping_dict = {} + assert isinstance(mapping_dict, dict), ( + "Provided mapping_dict is not a dictionary" + ) + self.mapping_dict = mapping_dict + self.default_value = default_value + + def apply(self, resps, docs): + def filter_set(inst): + return [self.mapping_dict.get(resp, self.default_value) for resp in inst] + + return [filter_set(resp) for resp in resps] + + +@register_filter("format_span") +class SPANFilter(Filter): + def __init__(self) -> None: + pass + + def apply(self, resps, docs): + def format_ner_text(text): + label_dict = { + "person": "PER", + "location": "LOC", + "organization": "ORG", + "counties": "LOC", + "places": "LOC", + "people": "PER", + "persons": "PER", + "company": "ORG", + "country": "LOC", + "continent": "LOC", + "time": "DATE", + "date": "DATE", + "per": "PER", + "loc": "LOC", + "org": "ORG", + } + text = text.lower() + for key, value in label_dict.items(): + text = text.replace(key, value) + + text = "$".join(i for i in text.split("$$")) + return text.rstrip("$$") + + def format_named_entities(text): + """ + Extract named entities from text and format them as 'label: value $$ label: value'. + Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values. + """ + # Regular expression to match label: entities pattern + pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)" + # Normalize newline characters + text = text.replace("\n", "$").strip() + matches = re.findall(pattern, text) + + formatted_entities = [] + + for label, values in matches: + # Split multiple entities separated by commas and strip whitespace + entities = [value.strip() for value in values.split(",")] + + # Exclude 'none' entities + for entity in entities: + if entity.lower() != "none": + formatted_entities.append(f"{label.lower()}: {entity}") + + # Join entities with the desired separator + return " $ ".join(formatted_entities) + + def filter_set(inst): + return [ + format_named_entities(format_ner_text(resp.lower())) for resp in inst + ] + + return [filter_set(resp) for resp in resps] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7a6834c6486fde35ef02d715e90be3fba223a --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/__init__.py @@ -0,0 +1,2 @@ +from .evaluation_tracker import EvaluationTracker +from .wandb_logger import WandbLogger diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/evaluation_tracker.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/evaluation_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7f88978e73a8fad88d83a9563e85090b8c7e5594 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/evaluation_tracker.py @@ -0,0 +1,530 @@ +import json +import logging +import os +import re +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path + +from datasets import load_dataset +from datasets.utils.metadata import MetadataConfigs +from huggingface_hub import ( + DatasetCard, + DatasetCardData, + HfApi, + hf_hub_url, +) +from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status + +from dllm_eval.utils import ( + get_file_datetime, + get_file_task_name, + get_results_filenames, + get_sample_results_filenames, + handle_non_serializable, + hash_string, + sanitize_list, + sanitize_model_name, + sanitize_task_name, +) + + +eval_logger = logging.getLogger(__name__) + + +@dataclass(init=False) +class GeneralConfigTracker: + """ + Tracker for the evaluation parameters. + + Attributes: + model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.) + model_name (str): Name of the model. + model_name_sanitized (str): Sanitized model name for directory creation. + start_time (float): Start time of the experiment. Logged at class init. + end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`] + total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times). + """ + + model_source: str = None + model_name: str = None + model_name_sanitized: str = None + system_instruction: str = None + system_instruction_sha: str = None + fewshot_as_multiturn: bool = None + chat_template: str = None + chat_template_sha: str = None + start_time: float = None + end_time: float = None + total_evaluation_time_seconds: str = None + + def __init__(self) -> None: + """Starts the evaluation timer.""" + self.start_time = time.perf_counter() + + @staticmethod + def _get_model_name(model_args: str) -> str: + """Extracts the model name from the model arguments.""" + + def extract_model_name(model_args: str, key: str) -> str: + """Extracts the model name from the model arguments using a key.""" + args_after_key = model_args.split(key)[1] + return args_after_key.split(",")[0] + + # order does matter, e.g. peft and delta are provided together with pretrained + prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="] + for prefix in prefixes: + if prefix in model_args: + return extract_model_name(model_args, prefix) + return "" + + def log_experiment_args( + self, + model_source: str, + model_args: str, + system_instruction: str, + chat_template: str, + fewshot_as_multiturn: bool, + ) -> None: + """Logs model parameters and job ID.""" + self.model_source = model_source + self.model_name = GeneralConfigTracker._get_model_name(model_args) + self.model_name_sanitized = sanitize_model_name(self.model_name) + self.system_instruction = system_instruction + self.system_instruction_sha = ( + hash_string(system_instruction) if system_instruction else None + ) + self.chat_template = chat_template + self.chat_template_sha = hash_string(chat_template) if chat_template else None + self.fewshot_as_multiturn = fewshot_as_multiturn + + def log_end_time(self) -> None: + """Logs the end time of the evaluation and calculates the total evaluation time.""" + self.end_time = time.perf_counter() + self.total_evaluation_time_seconds = str(self.end_time - self.start_time) + + +class EvaluationTracker: + """ + Keeps track and saves relevant information of the evaluation process. + Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested. + """ + + def __init__( + self, + output_path: str = None, + hub_results_org: str = "", + hub_repo_name: str = "", + details_repo_name: str = "", + results_repo_name: str = "", + push_results_to_hub: bool = False, + push_samples_to_hub: bool = False, + public_repo: bool = False, + token: str = "", + leaderboard_url: str = "", + point_of_contact: str = "", + gated: bool = False, + ) -> None: + """ + Creates all the necessary loggers for evaluation tracking. + + Args: + output_path (str): Path to save the results. If not provided, the results won't be saved. + hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token. + hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`. + details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`. + result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo. + push_results_to_hub (bool): Whether to push the results to the Hugging Face hub. + push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub. + public_repo (bool): Whether to push the results to a public or private repository. + token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`. + leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card. + point_of_contact (str): Contact information on the Hugging Face hub dataset card. + gated (bool): Whether to gate the repository. + """ + self.general_config_tracker = GeneralConfigTracker() + + self.output_path = output_path + self.push_results_to_hub = push_results_to_hub + self.push_samples_to_hub = push_samples_to_hub + self.public_repo = public_repo + self.leaderboard_url = leaderboard_url + self.point_of_contact = point_of_contact + self.api = HfApi(token=token) if token else None + self.gated_repo = gated + + if not self.api and (push_results_to_hub or push_samples_to_hub): + raise ValueError( + "Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. " + "Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable." + ) + + if ( + self.api + and hub_results_org == "" + and (push_results_to_hub or push_samples_to_hub) + ): + hub_results_org = self.api.whoami()["name"] + eval_logger.warning( + f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'." + ) + + if hub_repo_name == "": + details_repo_name = ( + details_repo_name if details_repo_name != "" else "lm-eval-results" + ) + results_repo_name = ( + results_repo_name if results_repo_name != "" else details_repo_name + ) + else: + details_repo_name = hub_repo_name + results_repo_name = hub_repo_name + eval_logger.warning( + "hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead." + ) + + self.details_repo = f"{hub_results_org}/{details_repo_name}" + self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private" + self.results_repo = f"{hub_results_org}/{results_repo_name}" + self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private" + + def save_results_aggregated( + self, + results: dict, + samples: dict, + ) -> None: + """ + Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested. + + Args: + results (dict): The aggregated results to save. + samples (dict): The samples results to save. + """ + self.general_config_tracker.log_end_time() + + if self.output_path: + try: + eval_logger.info("Saving results aggregated") + + # calculate cumulative hash for each task - only if samples are provided + task_hashes = {} + if samples: + for task_name, task_samples in samples.items(): + sample_hashes = [ + s["doc_hash"] + s["prompt_hash"] + s["target_hash"] + for s in task_samples + ] + task_hashes[task_name] = hash_string("".join(sample_hashes)) + + # update initial results dict + results.update({"task_hashes": task_hashes}) + results.update(asdict(self.general_config_tracker)) + dumped = json.dumps( + results, + indent=2, + default=handle_non_serializable, + ensure_ascii=False, + ) + + path = Path(self.output_path if self.output_path else Path.cwd()) + self.date_id = datetime.now().isoformat().replace(":", "-") + if path.suffix == ".json": + path.parent.mkdir(parents=True, exist_ok=True) + file_results_aggregated = path.with_name( + f"{path.stem}_{self.date_id}.json" + ) + else: + path.mkdir(parents=True, exist_ok=True) + file_results_aggregated = path.joinpath( + f"results_{self.date_id}.json" + ) + + file_results_aggregated.open("w", encoding="utf-8").write(dumped) + + if self.api and self.push_results_to_hub: + repo_id = ( + self.results_repo + if self.public_repo + else self.results_repo_private + ) + self.api.create_repo( + repo_id=repo_id, + repo_type="dataset", + private=not self.public_repo, + exist_ok=True, + ) + self.api.upload_file( + repo_id=repo_id, + path_or_fileobj=str(file_results_aggregated), + path_in_repo=os.path.join( + self.general_config_tracker.model_name, + file_results_aggregated.name, + ), + repo_type="dataset", + commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}", + ) + eval_logger.info( + "Successfully pushed aggregated results to the Hugging Face Hub. " + f"You can find them at: {repo_id}" + ) + + except Exception as e: + eval_logger.warning("Could not save results aggregated") + eval_logger.info(repr(e)) + else: + eval_logger.info( + "Output path not provided, skipping saving results aggregated" + ) + + def save_results_samples( + self, + task_name: str, + samples: dict, + ) -> None: + """ + Saves the samples results to the output path and pushes them to the Hugging Face hub if requested. + + Args: + task_name (str): The task name to save the samples for. + samples (dict): The samples results to save. + """ + if self.output_path: + try: + eval_logger.info(f"Saving per-sample results for: {task_name}") + + path = Path(self.output_path if self.output_path else Path.cwd()) + if path.suffix == ".json": + path = path.parent + path.mkdir(parents=True, exist_ok=True) + + file_results_samples = path.joinpath( + f"samples_{task_name}_{self.date_id}.jsonl" + ) + + for sample in samples: + # we first need to sanitize arguments and resps + # otherwise we won't be able to load the dataset + # using the datasets library + arguments = {} + for i, arg in enumerate(sample["arguments"]): + arguments[f"gen_args_{i}"] = {} + for j, tmp in enumerate(arg): + arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp + + sample["resps"] = sanitize_list(sample["resps"]) + sample["filtered_resps"] = sanitize_list(sample["filtered_resps"]) + sample["arguments"] = arguments + sample["target"] = str(sample["target"]) + + sample_dump = ( + json.dumps( + sample, + default=handle_non_serializable, + ensure_ascii=False, + ) + + "\n" + ) + + with open(file_results_samples, "a", encoding="utf-8") as f: + f.write(sample_dump) + + if self.api and self.push_samples_to_hub: + repo_id = ( + self.details_repo + if self.public_repo + else self.details_repo_private + ) + self.api.create_repo( + repo_id=repo_id, + repo_type="dataset", + private=not self.public_repo, + exist_ok=True, + ) + try: + if self.gated_repo: + headers = build_hf_headers() + r = get_session().put( + url=f"https://huggingface.co/api/datasets/{repo_id}/settings", + headers=headers, + json={"gated": "auto"}, + ) + hf_raise_for_status(r) + except Exception as e: + eval_logger.warning("Could not gate the repository") + eval_logger.info(repr(e)) + self.api.upload_folder( + repo_id=repo_id, + folder_path=str(path), + path_in_repo=self.general_config_tracker.model_name_sanitized, + repo_type="dataset", + commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}", + ) + eval_logger.info( + f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. " + f"You can find them at: {repo_id}" + ) + + except Exception as e: + eval_logger.warning("Could not save sample results") + eval_logger.info(repr(e)) + else: + eval_logger.info("Output path not provided, skipping saving sample results") + + def recreate_metadata_card(self) -> None: + """ + Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub. + """ + + eval_logger.info("Recreating metadata card") + repo_id = self.details_repo if self.public_repo else self.details_repo_private + + files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset") + results_files = get_results_filenames(files_in_repo) + sample_files = get_sample_results_filenames(files_in_repo) + + # Build a dictionary to store the latest evaluation datetime for: + # - Each tested model and its aggregated results + # - Each task and sample results, if existing + # i.e. { + # "org__model_name__gsm8k": "2021-09-01T12:00:00", + # "org__model_name__ifeval": "2021-09-01T12:00:00", + # "org__model_name__results": "2021-09-01T12:00:00" + # } + latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat()) + + for file_path in sample_files: + file_path = Path(file_path) + filename = file_path.name + model_name = file_path.parent + task_name = get_file_task_name(filename) + results_datetime = get_file_datetime(filename) + task_name_sanitized = sanitize_task_name(task_name) + # Results and sample results for the same model and task will have the same datetime + samples_key = f"{model_name}__{task_name_sanitized}" + results_key = f"{model_name}__results" + latest_datetime = max( + latest_task_results_datetime[samples_key], + results_datetime, + ) + latest_task_results_datetime[samples_key] = latest_datetime + latest_task_results_datetime[results_key] = max( + latest_task_results_datetime[results_key], + latest_datetime, + ) + + # Create metadata card + card_metadata = MetadataConfigs() + + # Add the latest aggregated results to the metadata card for easy access + for file_path in results_files: + file_path = Path(file_path) + results_filename = file_path.name + model_name = file_path.parent + eval_date = get_file_datetime(results_filename) + eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date) + results_filename = Path("**") / Path(results_filename).name + config_name = f"{model_name}__results" + sanitized_last_eval_date_results = re.sub( + r"[^\w\.]", "_", latest_task_results_datetime[config_name] + ) + + if eval_date_sanitized == sanitized_last_eval_date_results: + # Ensure that all results files are listed in the metadata card + current_results = card_metadata.get(config_name, {"data_files": []}) + current_results["data_files"].append( + {"split": eval_date_sanitized, "path": [str(results_filename)]} + ) + card_metadata[config_name] = current_results + # If the results file is the newest, update the "latest" field in the metadata card + card_metadata[config_name]["data_files"].append( + {"split": "latest", "path": [str(results_filename)]} + ) + + # Add the tasks details configs + for file_path in sample_files: + file_path = Path(file_path) + filename = file_path.name + model_name = file_path.parent + task_name = get_file_task_name(filename) + eval_date = get_file_datetime(filename) + task_name_sanitized = sanitize_task_name(task_name) + eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date) + results_filename = Path("**") / Path(filename).name + config_name = f"{model_name}__{task_name_sanitized}" + sanitized_last_eval_date_results = re.sub( + r"[^\w\.]", "_", latest_task_results_datetime[config_name] + ) + if eval_date_sanitized == sanitized_last_eval_date_results: + # Ensure that all sample results files are listed in the metadata card + current_details_for_task = card_metadata.get( + config_name, {"data_files": []} + ) + current_details_for_task["data_files"].append( + {"split": eval_date_sanitized, "path": [str(results_filename)]} + ) + card_metadata[config_name] = current_details_for_task + # If the samples results file is the newest, update the "latest" field in the metadata card + card_metadata[config_name]["data_files"].append( + {"split": "latest", "path": [str(results_filename)]} + ) + + # Get latest results and extract info to update metadata card examples + latest_datetime = max(latest_task_results_datetime.values()) + latest_model_name = max( + latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k] + ) + last_results_file = [ + f for f in results_files if latest_datetime.replace(":", "-") in f + ][0] + last_results_file_path = hf_hub_url( + repo_id=repo_id, filename=last_results_file, repo_type="dataset" + ) + latest_results_file = load_dataset( + "json", data_files=last_results_file_path, split="train" + ) + results_dict = latest_results_file["results"][0] + new_dictionary = {"all": results_dict} + new_dictionary.update(results_dict) + results_string = json.dumps(new_dictionary, indent=4) + + dataset_summary = ( + "Dataset automatically created during the evaluation run of model " + ) + if self.general_config_tracker.model_source == "hf": + dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n" + else: + dataset_summary += f"{self.general_config_tracker.model_name}\n" + dataset_summary += ( + f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n" + f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each " + 'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n' + 'An additional configuration "results" store all the aggregated results of the run.\n\n' + "To load the details from a run, you can for instance do the following:\n" + ) + if self.general_config_tracker.model_source == "hf": + dataset_summary += ( + "```python\nfrom datasets import load_dataset\n" + f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n' + ) + dataset_summary += ( + "## Latest results\n\n" + f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) " + "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. " + 'You find each in the results and the "latest" split for each eval):\n\n' + f"```python\n{results_string}\n```" + ) + card_data = DatasetCardData( + dataset_summary=dataset_summary, + repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}", + pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}", + leaderboard_url=self.leaderboard_url, + point_of_contact=self.point_of_contact, + ) + card_metadata.to_dataset_card_data(card_data) + card = DatasetCard.from_template( + card_data, + pretty_name=card_data.pretty_name, + ) + card.push_to_hub(repo_id, repo_type="dataset") diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ba795edb72d7b665a2c0fe6d4f3e3a5ed91b6940 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/utils.py @@ -0,0 +1,149 @@ +import logging +import os +import re +import subprocess +from importlib.metadata import version +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +from torch.utils.collect_env import get_pretty_env_info +from transformers import __version__ as trans_version + + +logger = logging.getLogger(__name__) + + +def remove_none_pattern(input_string: str) -> Tuple[str, bool]: + """Remove the ',none' substring from the input_string if it exists at the end. + + Args: + input_string (str): The input string from which to remove the ',none' substring. + + Returns: + Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed + and a boolean indicating whether the modification was made (True) or not (False). + """ + # Define the pattern to match ',none' at the end of the string + pattern = re.compile(r",none$") + + # Use sub() to replace ',none' with an empty string + result = re.sub(pattern, "", input_string) + + # check if the input_string changed + removed = result != input_string + + return result, removed + + +def _handle_non_serializable(o: Any) -> Union[int, str, list]: + """Handle non-serializable objects by converting them to serializable types. + + Args: + o (Any): The object to be handled. + + Returns: + Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32, + it will be converted to int. If the object is of type set, it will be converted + to a list. Otherwise, it will be converted to str. + """ + if isinstance(o, np.int64) or isinstance(o, np.int32): + return int(o) + elif isinstance(o, set): + return list(o) + else: + return str(o) + + +def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]: + try: + git_folder = Path(repo_path, ".git") + if git_folder.is_file(): + git_folder = Path( + git_folder.parent, + git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1], + ) + if Path(git_folder, "HEAD").exists(): + head_name = ( + Path(git_folder, "HEAD") + .read_text(encoding="utf-8") + .split("\n")[0] + .split(" ")[-1] + ) + head_ref = Path(git_folder, head_name) + git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "") + else: + git_hash = None + except Exception as err: + logger.debug( + f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}" + ) + return None + return git_hash + + +def get_git_commit_hash(): + """ + Gets the git commit hash of your current repo (if it exists). + Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42 + """ + try: + git_hash = subprocess.check_output(["git", "describe", "--always"]).strip() + git_hash = git_hash.decode() + except (subprocess.CalledProcessError, FileNotFoundError): + # FileNotFoundError occurs when git not installed on system + git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists + return git_hash + + +def add_env_info(storage: Dict[str, Any]): + try: + pretty_env_info = get_pretty_env_info() + except Exception as err: + pretty_env_info = str(err) + try: + dllm_eval_version = version("dllm_eval") + except Exception as err: + dllm_eval_version = str(err) + transformers_version = trans_version + upper_dir_commit = get_commit_from_path( + Path(os.getcwd(), "..") + ) # git hash of upper repo if exists + added_info = { + "pretty_env_info": pretty_env_info, + "transformers_version": transformers_version, + "dllm_eval_version": dllm_eval_version, + "upper_git_hash": upper_dir_commit, # in case this repo is submodule + } + storage.update(added_info) + + +def add_tokenizer_info(storage: Dict[str, Any], lm): + if getattr(lm, "tokenizer", False): + try: + tokenizer_info = { + "tokenizer_pad_token": [ + lm.tokenizer.pad_token, + str(lm.tokenizer.pad_token_id), + ], + "tokenizer_eos_token": [ + lm.tokenizer.eos_token, + str(lm.tokenizer.eos_token_id), + ], + "tokenizer_bos_token": [ + lm.tokenizer.bos_token, + str(lm.tokenizer.bos_token_id), + ], + "eot_token_id": getattr(lm, "eot_token_id", None), + "max_length": getattr(lm, "max_length", None), + } + storage.update(tokenizer_info) + except Exception as err: + logger.debug( + f"Logging detailed tokenizer info failed with {err}, skipping..." + ) + # seems gguf and textsynth do not have tokenizer + else: + logger.debug( + "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results." + ) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/wandb_logger.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/wandb_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0859b3c8e90437f21b6f06143b14941a7a96d2 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/loggers/wandb_logger.py @@ -0,0 +1,358 @@ +import copy +import json +import logging +from typing import Any, Dict, List, Literal, Tuple + +import numpy as np +import pandas as pd +from packaging.version import Version + +from dllm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern + + +logger = logging.getLogger(__name__) + + +def get_wandb_printer() -> Literal["Printer"]: + """Returns a wandb printer instance for pretty stdout.""" + from wandb.sdk.lib.printer import new_printer + + printer = new_printer() + return printer + + +class WandbLogger: + def __init__(self, init_args=None, config_args=None) -> None: + """Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update() + + Args: + init_args Optional[Dict]: Arguments for init configuration. + config_args Optional[Dict]: Arguments for config + + Parse and log the results returned from evaluator.simple_evaluate() with: + wandb_logger.post_init(results) + wandb_logger.log_eval_result() + wandb_logger.log_eval_samples(results["samples"]) + """ + try: + import wandb + + assert Version(wandb.__version__) >= Version("0.13.6") + if Version(wandb.__version__) < Version("0.13.6"): + wandb.require("report-editing:v0") + except Exception as e: + logger.warning( + "To use the wandb reporting functionality please install wandb>=0.13.6.\n" + "To install the latest version of wandb run `pip install wandb --upgrade`\n" + f"{e}" + ) + + self.wandb_args: Dict[str, Any] = init_args or {} + self.wandb_config_args: Dict[str, Any] = config_args or {} + + # pop the step key from the args to save for all logging calls + self.step = self.wandb_args.pop("step", None) + + # initialize a W&B run + if wandb.run is None: + self.run = wandb.init(**self.wandb_args) + if self.wandb_config_args: + self.run.config.update(self.wandb_config_args) + else: + self.run = wandb.run + + self.printer = get_wandb_printer() + + def post_init(self, results: Dict[str, Any]) -> None: + self.results: Dict[str, Any] = copy.deepcopy(results) + self.task_names: List[str] = list(results.get("results", {}).keys()) + self.group_names: List[str] = list(results.get("groups", {}).keys()) + + def _get_config(self) -> Dict[str, Any]: + """Get configuration parameters.""" + self.task_configs = self.results.get("configs", {}) + cli_configs = self.results.get("config", {}) + configs = { + "task_configs": self.task_configs, + "cli_configs": cli_configs, + } + + return configs + + def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]: + """Sanitize the results dictionary.""" + _results = copy.deepcopy(self.results.get("results", dict())) + + # Remove None from the metric string name + tmp_results = copy.deepcopy(_results) + for task_name in self.task_names: + task_result = tmp_results.get(task_name, dict()) + for metric_name, metric_value in task_result.items(): + _metric_name, removed = remove_none_pattern(metric_name) + if removed: + _results[task_name][_metric_name] = metric_value + _results[task_name].pop(metric_name) + + # remove string valued keys from the results dict + wandb_summary = {} + for task in self.task_names: + task_result = _results.get(task, dict()) + for metric_name, metric_value in task_result.items(): + if isinstance(metric_value, str): + wandb_summary[f"{task}/{metric_name}"] = metric_value + + for summary_metric, summary_value in wandb_summary.items(): + _task, _summary_metric = summary_metric.split("/") + _results[_task].pop(_summary_metric) + + tmp_results = copy.deepcopy(_results) + for task_name, task_results in tmp_results.items(): + for metric_name, metric_value in task_results.items(): + _results[f"{task_name}/{metric_name}"] = metric_value + _results[task_name].pop(metric_name) + for task in self.task_names: + _results.pop(task) + + return wandb_summary, _results + + def _log_results_as_table(self) -> None: + """Generate and log evaluation results as a table to W&B.""" + columns = [ + "Version", + "Filter", + "num_fewshot", + "Metric", + "Value", + "Stderr", + ] + + def make_table(columns: List[str], key: str = "results"): + import wandb + + table = wandb.Table(columns=columns) + results = copy.deepcopy(self.results) + + for k, dic in results.get(key).items(): + if k in self.group_names and not key == "groups": + continue + version = results.get("versions").get(k) + if version == "N/A": + version = None + n = results.get("n-shot").get(k) + + for (mf), v in dic.items(): + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + if m == "alias": + continue + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + if se != "N/A": + se = "%.4f" % se + table.add_data(*[k, version, f, n, m, str(v), str(se)]) + else: + table.add_data(*[k, version, f, n, m, str(v), ""]) + + return table + + # log the complete eval result to W&B Table + table = make_table(["Tasks"] + columns, "results") + self.run.log({"evaluation/eval_results": table}, step=self.step) + + if "groups" in self.results.keys(): + table = make_table(["Groups"] + columns, "groups") + self.run.log({"evaluation/group_eval_results": table}, step=self.step) + + def _log_results_as_artifact(self) -> None: + """Log results as JSON artifact to W&B.""" + import wandb + + dumped = json.dumps( + self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False + ) + artifact = wandb.Artifact("results", type="eval_results") + with artifact.new_file("results.json", mode="w", encoding="utf-8") as f: + f.write(dumped) + self.run.log_artifact(artifact) + + def log_eval_result(self) -> None: + """Log evaluation results to W&B.""" + # Log configs to wandb + configs = self._get_config() + self.run.config.update(configs, allow_val_change=self.step is not None) + + wandb_summary, self.wandb_results = self._sanitize_results_dict() + # update wandb.run.summary with items that were removed + self.run.summary.update(wandb_summary) + # Log the evaluation metrics to wandb + self.run.log(self.wandb_results, step=self.step) + # Log the evaluation metrics as W&B Table + self._log_results_as_table() + # Log the results dict as json to W&B Artifacts + self._log_results_as_artifact() + + def _generate_dataset( + self, data: List[Dict[str, Any]], config: Dict[str, Any] + ) -> pd.DataFrame: + """Generate a dataset from evaluation data. + + Args: + data (List[Dict[str, Any]]): The data to generate a dataset for. + config (Dict[str, Any]): The configuration of the task. + + Returns: + pd.DataFrame: A dataframe that is ready to be uploaded to W&B. + """ + ids = [x["doc_id"] for x in data] + labels = [x["target"] for x in data] + instance = [""] * len(ids) + resps = [""] * len(ids) + filtered_resps = [""] * len(ids) + model_outputs = {} + + metrics_list = config["metric_list"] + metrics = {} + for metric in metrics_list: + metric = metric.get("metric") + if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data] + if metric in ["byte_perplexity", "bits_per_byte"]: + metrics[f"{metric}_bytes"] = [x[metric][1] for x in data] + else: + metrics[f"{metric}_words"] = [x[metric][1] for x in data] + else: + metrics[metric] = [x[metric] for x in data] + + if config["output_type"] == "loglikelihood": + instance = [x["arguments"][0][0] for x in data] + labels = [x["arguments"][0][1] for x in data] + resps = [ + f"log probability of continuation is {x['resps'][0][0][0]} " + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["resps"][0][0][1] else "be" + ) + for x in data + ] + filtered_resps = [ + f"log probability of continuation is {x['filtered_resps'][0][0]} " + + "\n\n" + + "continuation will {} generated with greedy sampling".format( + "not be" if not x["filtered_resps"][0][1] else "be" + ) + for x in data + ] + elif config["output_type"] == "multiple_choice": + instance = [x["arguments"][0][0] for x in data] + choices = [ + "\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])]) + for x in data + ] + resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data] + filtered_resps = [ + np.argmax([n[0] for n in x["filtered_resps"]]) for x in data + ] + elif config["output_type"] == "loglikelihood_rolling": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + elif config["output_type"] == "generate_until": + instance = [x["arguments"][0][0] for x in data] + resps = [x["resps"][0][0] for x in data] + filtered_resps = [x["filtered_resps"][0] for x in data] + + model_outputs["raw_predictions"] = resps + model_outputs["filtered_predictions"] = filtered_resps + + df_data = { + "id": ids, + "data": instance, + } + if config["output_type"] == "multiple_choice": + df_data["choices"] = choices + + tmp_data = { + "input_len": [len(x) for x in instance], + "labels": labels, + "output_type": config["output_type"], + } + df_data.update(tmp_data) + df_data.update(model_outputs) + df_data.update(metrics) + + return pd.DataFrame(df_data) + + def _log_samples_as_artifact( + self, data: List[Dict[str, Any]], task_name: str + ) -> None: + import wandb + + # log the samples as an artifact + dumped = json.dumps( + data, + indent=2, + default=_handle_non_serializable, + ensure_ascii=False, + ) + artifact = wandb.Artifact(f"{task_name}", type="samples_by_task") + with artifact.new_file( + f"{task_name}_eval_samples.json", mode="w", encoding="utf-8" + ) as f: + f.write(dumped) + self.run.log_artifact(artifact) + # artifact.wait() + + def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None: + """Log evaluation samples to W&B. + + Args: + samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task. + """ + task_names: List[str] = [ + x for x in self.task_names if x not in self.group_names + ] + + ungrouped_tasks = [] + tasks_by_groups = {} + + for task_name in task_names: + group_names = self.task_configs[task_name].get("group", None) + if group_names: + if isinstance(group_names, str): + group_names = [group_names] + + for group_name in group_names: + if not tasks_by_groups.get(group_name): + tasks_by_groups[group_name] = [task_name] + else: + tasks_by_groups[group_name].append(task_name) + else: + ungrouped_tasks.append(task_name) + + for task_name in ungrouped_tasks: + eval_preds = samples[task_name] + + # log the samples as a W&B Table + df = self._generate_dataset(eval_preds, self.task_configs.get(task_name)) + self.run.log({f"{task_name}_eval_results": df}, step=self.step) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + for group, grouped_tasks in tasks_by_groups.items(): + grouped_df = pd.DataFrame() + for task_name in grouped_tasks: + eval_preds = samples[task_name] + df = self._generate_dataset( + eval_preds, self.task_configs.get(task_name) + ) + df["group"] = group + df["task"] = task_name + grouped_df = pd.concat([grouped_df, df], ignore_index=True) + + # log the samples as a json file as W&B Artifact + self._log_samples_as_artifact(eval_preds, task_name) + + self.run.log({f"{group}_eval_results": grouped_df}, step=self.step) diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/LLaDA2.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/LLaDA2.py new file mode 100644 index 0000000000000000000000000000000000000000..783400310b4342ecc7a671926fa8e7afe3b05620 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/LLaDA2.py @@ -0,0 +1,726 @@ +import logging +import os +from datetime import timedelta +from typing import Dict, List, Literal, Optional, Tuple, Union, TypeVar +import torch +import torch.nn.functional as F +import numpy as np +import transformers +import json +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, +) +from datasets import Dataset +from accelerate.utils import get_max_memory +from packaging import version +from tqdm import tqdm +import torch.distributed as dist +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) +from dllm_eval.api.instance import Instance +from dllm_eval.api.model import LM, TemplateLM +from dllm_eval.api.registry import register_model +from dllm_eval.models.utils import get_dtype, configure_pad_token + +try: + from .hts_sampler import HTSSampler +except ImportError: + HTSSampler = None + +eval_logger = logging.getLogger(__name__) +T = TypeVar("T", bound="LM") + + +def add_gumbel_noise(logits, temperature): + if temperature == 0.0: + return logits + logits = logits.to(torch.float32) + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = (-torch.log(noise)) ** temperature + return logits.exp() / gumbel_noise + + +def get_num_transfer_tokens(mask_index, steps): + mask_num = mask_index.sum(dim=1, keepdim=True) + base = mask_num // steps + remainder = mask_num % steps + num_transfer_tokens = base.expand(-1, steps).clone() + if remainder.sum() > 0: + indices = torch.arange(steps, device=mask_index.device) + mask = indices.unsqueeze(0) < remainder + num_transfer_tokens[mask] += 1 + return num_transfer_tokens.to(torch.int64) + + +@register_model("LLaDA2") +class LLaDA2(TemplateLM): + AUTO_MODEL_CLASS = transformers.AutoModel + _DEFAULT_MAX_LENGTH = 20480 + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "causal", + revision: Optional[str] = "main", + subfolder: Optional[str] = None, + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + batch_size: Optional[Union[int]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = True, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = False, + escape_until:Optional[bool] = False, + prefix_token_id: Optional[int] = None, + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + mc_num: int = 1024, + remasking: str = "low_confidence", + mask_id: int = 156895, + is_check_greedy : bool =True, + assistant_prefix: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + self.mc_num = mc_num + self.mask_id = mask_id + self.remasking = remasking + self.pretrained = pretrained + self.is_check_greedy = is_check_greedy + self.assistant_prefix = assistant_prefix + self.add_bos_token = add_bos_token + self.escape_until = escape_until + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + if not (parallelize or accelerator.num_processes > 1): + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + revision = str(revision) + revision = revision + ("/" + subfolder if subfolder is not None else "") + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + ) + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + **kwargs, + ) + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + self.add_bos_token = add_bos_token + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." + ) + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + if not (parallelize or autogptq or hasattr(self, "accelerator")): + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." + ) + if gpus > 1: + if hasattr(self, "accelerator") and self.accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." + ) + elif gpus > self.accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes. " + "If you would like to use data parallelism, please launch the script " + "with 'accelerate launch *script*'. " + f"Current run will proceed with {self.accelerator.num_processes} devices." + ) + if self.accelerator.is_local_main_process: + eval_logger.info( + f"Using {gpus} devices with data parallelism" + ) + + self._device = torch.device(f"{self.accelerator.device}") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self._rank = 0 + self._world_size = 1 + else: + self._rank = 0 + self._world_size = 1 + else: + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + self.is_first_inference = True + + if HTSSampler is not None: + self.hts_sampler = HTSSampler(self.model, self.tokenizer, device=self.device) + eval_logger.info("HTSSampler initialized successfully.") + + @property + def rank(self): + if hasattr(self, "_rank"): + return self._rank + if hasattr(self, "accelerator"): + return self.accelerator.local_process_index + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def world_size(self): + if hasattr(self, "_world_size"): + return self._world_size + if hasattr(self, "accelerator"): + return self.accelerator.num_processes + return int(os.environ.get("WORLD_SIZE", 1)) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + if parallelize is None and gpus is not None and gpus > 1: + parallelize = True + args = {} + if parallelize: + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(len(max_memory_all_gpus)) + } if max_memory_per_gpu is not None else {k: v for k, v in max_memory_all_gpus.items()} + if hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() if k % num_local_processes == self.accelerator.process_index % num_local_processes + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" + args["offload_folder"] = offload_folder + if max_cpu_memory is not None: + args["max_memory"]["cpu"] = max_cpu_memory + eval_logger.info( + f"Model parallel set to True. Max memory per GPU: {args['max_memory']}, Device map: {args['device_map']}" + ) + else: + args["device_map"] = {"": str(self.device)} + eval_logger.info( + f"Model parallel set to False. Device map: {args['device_map']}" + ) + return args + + @property + def config(self): + return self._config + + @property + def model(self): + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length > 1e10: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _get_backend( + self, + config: Union[transformers.PretrainedConfig, transformers.AutoConfig], + backend: Literal["default", "causal", "seq2seq"] = "default", + trust_remote_code: Optional[bool] = False, + ) -> None: + assert backend in ["default", "causal", "seq2seq"] + if backend != "default": + self.backend = backend + eval_logger.info( + f"Overrode HF model backend type, and using type '{self.backend}'" + ) + else: + if ( + getattr(config, "model_type") + in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "seq2seq" + elif ( + getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "causal" + else: + eval_logger.warning( + "HF model type is neither CausalLM nor Seq2SeqLM. Assuming CausalLM." + ) + self.backend = "causal" + + def _get_config( + self, + pretrained: str, + revision: str = "main", + trust_remote_code: bool = False, + gguf_file: Optional[str] = None, + ) -> None: + self._config = transformers.AutoConfig.from_pretrained( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + ) + + def _create_model( + self, + pretrained: str, + revision: Optional[str] = "main", + dtype: Optional[Union[str, torch.dtype]] = "auto", + trust_remote_code: Optional[bool] = False, + parallelize: Optional[bool] = False, + gpus: Optional[int] = None, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + if autogptq or gptqmodel: + raise NotImplementedError("Quantization options are not implemented for this custom class.") + model_dtype = get_dtype(dtype) + eval_logger.info(f"Loading model with dtype: {model_dtype}") + model_kwargs = kwargs if kwargs else {} + if not parallelize: + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + ) + ) + self._model = transformers.AutoModelForCausalLM.from_pretrained( + pretrained, + revision=revision, + torch_dtype=model_dtype, + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + if peft: + from peft import PeftModel + eval_logger.info(f"Loading PEFT model from {peft}") + self._model = PeftModel.from_pretrained(self._model, peft, torch_dtype=model_dtype) + if not parallelize: + self._model = self._model.to(self.device) + self._model = self._model.to(torch.bfloat16) + self._model.eval() + + def _create_tokenizer( + self, + pretrained: Union[str, transformers.PreTrainedModel], + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ], + revision: Optional[str] = "main", + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + gguf_file: Optional[str] = None, + add_bos_token: Optional[bool] = False, + ) -> None: + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + "use_fast": use_fast_tokenizer + } + if add_bos_token: + kwargs["add_bos_token"] = True + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, **kwargs) + else: + self.tokenizer = tokenizer + else: + model_name = pretrained if isinstance(pretrained, str) else self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, **kwargs) + + def tok_encode( + self, string: str, left_truncate_len=None, add_special_tokens=None + ) -> List[int]: + special_tokens_kwargs = {} + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs["add_special_tokens"] = self.add_bos_token + else: + special_tokens_kwargs["add_special_tokens"] = add_special_tokens + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + return encoding + + def tok_batch_encode( + self, + strings: List[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + add_special_tokens = {"add_special_tokens": self.add_bos_token} if self.backend == "causal" else {} + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len and encoding["input_ids"].size(1) > left_truncate_len: + eval_logger.warning( + f"Left-truncating from {encoding['input_ids'].size(1)} to {left_truncate_len} tokens." + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:] + self.tokenizer.padding_side = old_padding_side + return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device) + + def tok_decode(self, tokens, skip_special_tokens=False): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + with torch.no_grad(): + if self.backend == "seq2seq": + return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits + else: + return self.model(inps, attention_mask=attn_mask).logits + + def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + raise NotImplementedError + + def loglikelihood(self, requests): + raise NotImplementedError + + def generate_until(self, requests: List[Instance]) -> List[str]: + res = [] + gen_kwargs = requests[0].args[1] + use_hts = gen_kwargs.get("use_hts", False) + + realtime_output = gen_kwargs.get("realtime_output", "realtime_hts_results.jsonl") + baseline_realtime_output = "realtime_baseline_results.jsonl" + + if not use_hts: + bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running Baseline") + ds_data = [{"text": req.args[0]} for req in requests] + ds = Dataset.from_list(ds_data) + + req_idx = 0 + for batch in ds.iter(batch_size=int(self.batch_size)): + contexts = batch["text"] + context_enc, _ = self.tok_batch_encode(contexts) + prompt_length = context_enc.shape[1] + + out_full = self.model.generate( + inputs=context_enc, + steps=gen_kwargs.get("steps", 32), + gen_length=gen_kwargs.get("gen_length", 512), + block_length=gen_kwargs.get("block_length", 32), + temperature=gen_kwargs.get("temperature", 0.7), + eos_early_stop=gen_kwargs.get("eos_early_stop", False), + ) + generated_tokens = out_full[:, prompt_length:] + cont_toks_list = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + + for i, s in enumerate(cont_toks_list): + s = s.strip() + + if not self.escape_until: + until_terms = gen_kwargs.get("until", []) + for term in until_terms: + if len(term) > 0 and term in s: + s = s.split(term)[0] + + orig_req = requests[req_idx] + target_val = getattr(orig_req, "target", None) + if target_val is None or target_val == "N/A": + if "test" in orig_req.doc and "entry_point" in orig_req.doc: + target_val = orig_req.doc["test"] + "\ncheck(" + orig_req.doc["entry_point"] + ")" + else: + target_val = orig_req.doc.get("answer", orig_req.doc.get("solution", "N/A")) + + with open(baseline_realtime_output, "a", encoding="utf-8") as f: + f.write(json.dumps({ + "doc": orig_req.doc, + "target": target_val, + "resps": [[s]], + "prompt": contexts[i] + }, ensure_ascii=False) + "\n") + f.flush() + + res.append(s) + bar.update(1) + req_idx += 1 + bar.close() + + else: + bar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Running HTS+SVF") + for req in requests: + prompt_text = req.args[0] + context_enc, _ = self.tok_batch_encode([prompt_text]) + + p_interval = int(gen_kwargs.get("pruning_interval", 0)) + + final_codes, stats = self.hts_sampler.generate_hts( + prompt_text=prompt_text, + input_ids=context_enc, + initial_N=int(gen_kwargs.get("hts_N", 4)), + final_K=int(gen_kwargs.get("final_K", 1)), + hts_survivor_k=int(gen_kwargs.get("hts_survivor_k", 4)), + hts_mode=gen_kwargs.get("hts_mode", True), + hts_start_pct=float(gen_kwargs.get("hts_start_pct", 0.1)), + hts_end_pct=float(gen_kwargs.get("hts_end_pct", 0.6)), + decay_factor=float(gen_kwargs.get("decay_factor", 1.5)), + pruning_interval=p_interval, + reward_mode=gen_kwargs.get("reward_mode", "svf"), + task_type=gen_kwargs.get("task_type", "code"), + steps=int(gen_kwargs.get("steps", 32)), + gen_length=int(gen_kwargs.get("gen_length", 512)), + block_length=int(gen_kwargs.get("block_length", 32)), + temperature=float(gen_kwargs.get("temperature", 0.7)), + top_p=float(gen_kwargs.get("top_p", 0.95)), + top_k=gen_kwargs.get("top_k", None), + threshold=float(gen_kwargs.get("threshold", 0.85)), + mask_id=self.mask_id, + eos_id=self.eot_token_id + ) + + processed_codes = [] + for code in final_codes: + code = code.strip() + if not self.escape_until: + until_terms = gen_kwargs.get("until", []) + for term in until_terms: + if len(term) > 0 and term in code: + code = code.split(term)[0] + processed_codes.append(code) + + final_choice = processed_codes[0] + res.append(final_choice) + + target_val = getattr(req, "target", None) + if target_val is None or target_val == "N/A": + if "test" in req.doc and "entry_point" in req.doc: + target_val = req.doc["test"] + "\ncheck(" + req.doc["entry_point"] + ")" + else: + target_val = req.doc.get("answer", req.doc.get("solution", "N/A")) + + with open(realtime_output, "a", encoding="utf-8") as f: + all_resps = [[code] for code in processed_codes] + + output_data = { + "doc": req.doc, + "target": target_val, + "resps": all_resps, + "prompt": prompt_text, + "entropy_history": stats.get("entropy_history", []), + "pruning_history": stats.get("pruning_history", []), + "final_scores": stats.get("final_scores", []), + "all_trajectories": stats.get("all_trajectories", []), + "nfe": stats.get("nfe", 0), + "svf_calls": stats.get("svf_calls", 0), + "total_steps": stats.get("total_steps", 0) + } + f.write(json.dumps(output_data, ensure_ascii=False) + "\n") + f.flush() + + bar.update(1) + bar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + if self.assistant_prefix: + chat_templated += self.assistant_prefix + return chat_templated \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b229acb5fb21c2f423fcd43a8a235b5a0d12239 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/__init__.py @@ -0,0 +1,19 @@ +from . import ( + LLaDA2, + huggingface, +) +# from .configuration_llada import LLaDAConfig +# from .modeling_llada import LLaDAModelLM + + +try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + import huggingface_hub.constants # type: ignore + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True +except ImportError: + pass + + +# __all__ = ['LLaDAConfig', 'LLaDAModelLM'] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/dummy.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..4702a36cb29809c9dd08c516b99e74e71ffcc166 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/dummy.py @@ -0,0 +1,41 @@ +import random + +from tqdm import tqdm + +from dllm_eval.api.model import LM +from dllm_eval.api.registry import register_model + + +@register_model("dummy") +class DummyLM(LM): + def __init__(self) -> None: + super().__init__() + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + return cls() + + def loglikelihood(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append((-random.random(), False)) + + return res + + def generate_until(self, requests, disable_tqdm: bool = False): + res = [] + + for request in tqdm(requests, disable=disable_tqdm): + res.append("lol") + assert request.arguments[0].strip() != "" + + return res + + def loglikelihood_rolling(self, requests, disable_tqdm: bool = False): + res = [] + + for _ in tqdm(requests, disable=disable_tqdm): + res.append(-random.random()) + + return res diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/hts_sampler.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/hts_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4221dc085672e7c41bfb6084224fa8ff883c38e9 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/hts_sampler.py @@ -0,0 +1,325 @@ +import torch +import torch.nn.functional as F +import numpy as np +from .verifier import CodeVerifier +import logging +import re +import math + +logger = logging.getLogger(__name__) + +class HTSSampler: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + self.verifier = CodeVerifier(model, tokenizer, device) + + def _get_num_transfer_tokens(self, block_length, steps): + if steps == 0: return torch.tensor([], dtype=torch.int64) + base = block_length // steps + remainder = block_length % steps + num_transfer_tokens = torch.full((steps,), base, dtype=torch.int64) + num_transfer_tokens[:remainder] += 1 + return num_transfer_tokens + + def _sample_with_temperature(self, logits, temperature, top_k, top_p): + logits = logits.to(torch.float32) + + orig_probs = torch.softmax(logits, dim=-1) + x0_p, _ = torch.max(orig_probs, dim=-1) + + if temperature > 0.0: + noise = torch.rand_like(logits, dtype=torch.float32) + gumbel_noise = -torch.log(-torch.log(noise + 1e-10) + 1e-10) + logits = logits / temperature + gumbel_noise + + if top_k is not None and top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = -float('Inf') + + x0 = torch.argmax(logits, dim=-1) + + return x0, x0_p + + def _safe_scalar(self, val): + if isinstance(val, torch.Tensor): + if val.numel() > 1: return val.mean().item() + return val.item() + return float(val) + + def _analyze_structure(self, text, task_type="code"): + score = 0.0 + stripped = text.strip() + if task_type == "code": + if len(stripped) < 5: return -0.1 + keywords = ["return", "print", "yield", "lambda", "class ", "def "] + if any(k in stripped for k in keywords): score += 0.05 + if ":" in stripped: score += 0.02 + if " " in text: score += 0.03 + elif task_type == "math": + if "\\boxed{" in stripped: score += 0.1 + if "The answer is" in stripped: score += 0.05 + if len(stripped) < 10: return -0.1 + if "Step" in stripped and stripped.count("Step") > 15: score -= 0.2 + return score + + def _chunked_forward(self, x, chunk_size=32, slice_start=None): + total_batch = x.shape[0] + logits_list = [] + for i in range(0, total_batch, chunk_size): + end_idx = min(i + chunk_size, total_batch) + sub_x = x[i:end_idx] + with torch.no_grad(): + outputs = self.model(input_ids=sub_x) + sub_logits = outputs.logits + if slice_start is not None: + s_start = slice_start if slice_start >= 0 else sub_logits.shape[1] + slice_start + sub_logits = sub_logits[:, s_start:, :] + logits_list.append(sub_logits.detach().clone()) + return torch.cat(logits_list, dim=0) + + def _branch_and_resample(self, x, conf_scores, survivor_indices, target_width, mask_id, + prompt_length, resample_window=5, task_type="code"): + num_survivors = len(survivor_indices) + if num_survivors == 0: return x[:target_width].clone(), conf_scores[:target_width].clone() + + if task_type == "math": resample_window = 12 + elif task_type == "reasoning": resample_window = 10 + elif task_type == "code": resample_window = 6 + + base_repeat = target_width // num_survivors + remainder = target_width % num_survivors + new_x_list = [] + new_conf_list = [] + + for i in range(num_survivors): + count = base_repeat + (1 if i < remainder else 0) + if count == 0: continue + + survivor_x = x[survivor_indices[i]] + survivor_conf = conf_scores[survivor_indices[i]] + + new_x_list.append(survivor_x.unsqueeze(0)) + new_conf_list.append(survivor_conf.unsqueeze(0)) + + if count > 1: + gen_part = survivor_x[prompt_length:] + gen_conf = survivor_conf[prompt_length:] + non_mask_indices = (gen_part != mask_id).nonzero(as_tuple=True)[0] + + for _ in range(count - 1): + perturbed_x = survivor_x.clone() + perturbed_conf = survivor_conf.clone() + + if len(non_mask_indices) > 0: + pool_size = min(resample_window * 2, len(non_mask_indices)) + current_token_confs = gen_conf[non_mask_indices] + + _, candidate_indices = torch.topk(current_token_confs, k=pool_size, largest=False) + + num_to_perturb = min(resample_window, pool_size) + rand_indices = torch.randperm(pool_size, device=self.device)[:num_to_perturb] + selected_sub_indices = candidate_indices[rand_indices] + + target_indices_in_x = prompt_length + non_mask_indices[selected_sub_indices] + perturbed_x[target_indices_in_x] = mask_id + perturbed_conf[target_indices_in_x] = 0.0 + + new_x_list.append(perturbed_x.unsqueeze(0)) + new_conf_list.append(perturbed_conf.unsqueeze(0)) + + return torch.cat(new_x_list, dim=0), torch.cat(new_conf_list, dim=0) + + @torch.no_grad() + def generate_hts(self, prompt_text, input_ids, problem_data=None, + initial_N=1, final_K=1, survivor_K=None, + prune_step_pct=0.0, reward_mode="confidence", + temperature=0.7, block_length=32, steps=64, gen_length=1024, + top_p=0.95, top_k=None, minimal_topk=1, threshold=0.9, + eos_id=156892, mask_id=156895, + hts_mode=False, hts_start_pct=0.1, hts_end_pct=0.6, decay_factor=1.5, + hts_survivor_k=4, task_type="code", until=None, pruning_interval=0): + + input_ids = input_ids.to(self.device) + if input_ids.shape[0] == 1: input_ids = input_ids.repeat(initial_N, 1) + + schedule_map = {} + ts_start, tr_end = 0, 0 + if not hts_mode: + final_K_list = [final_K] if not isinstance(final_K, list) else final_K + prune_pct_list = [prune_step_pct] if not isinstance(prune_step_pct, list) else prune_step_pct + survivor_K_list = final_K_list if survivor_K is None else ([survivor_K] if not isinstance(survivor_K, list) else survivor_K) + if len(survivor_K_list) < len(final_K_list): survivor_K_list.extend(final_K_list[len(survivor_K_list):]) + for pct, width, parents in zip(prune_pct_list, final_K_list, survivor_K_list): + if pct > 0: + s = int(steps * pct) + schedule_map[s] = (width, parents) + else: + final_K_list = [final_K] if not isinstance(final_K, int) else [final_K] + ts_start, tr_end = int(steps * hts_start_pct), int(steps * hts_end_pct) + + steps = min(steps, gen_length // minimal_topk) + prompt_length = input_ids.shape[1] + num_blocks = (prompt_length + gen_length + block_length - 1) // block_length + total_length = num_blocks * block_length + + x = torch.full((initial_N, total_length), mask_id, dtype=torch.long, device=self.device) + x[:, :prompt_length] = input_ids.clone() + + conf_scores = torch.zeros((initial_N, total_length), dtype=torch.float32, device=self.device) + conf_scores[:, :prompt_length] = 1.0 + + prefill_blocks = prompt_length // block_length + num_gen_blocks = max(1, num_blocks - prefill_blocks) + current_bsz = initial_N + + next_allowed_pruning_step = ts_start if hts_mode else 0 + + stats = { + "initial_n": initial_N, "final_k": final_K_list[-1], + "pruning_history": [], "entropy_history": [], "nfe": 0.0, + "svf_calls": 0, "final_scores": [], "total_steps": steps + } + + for num_block in range(prefill_blocks, num_blocks): + window_end = (num_block + 1) * block_length + schedule = self._get_num_transfer_tokens(block_length, steps) + + for step in range(steps): + cur_x = x[:current_bsz, :window_end] + + perform_pruning = False + num_parents_to_select = 0 + + if hts_mode and step >= next_allowed_pruning_step and step < tr_end: + target_width = max(final_K_list[-1], math.ceil(initial_N * (decay_factor ** -(step - ts_start)))) + if current_bsz > target_width: + perform_pruning = True + num_parents_to_select = hts_survivor_k + elif not hts_mode and step in schedule_map: + target_width, num_parents_to_select = schedule_map[step] + if current_bsz > target_width: perform_pruning = True + + if perform_pruning: + stats["nfe"] += current_bsz + stats["svf_calls"] += current_bsz + + gen_logits = self._chunked_forward(cur_x, chunk_size=16, slice_start=prompt_length) + rough_ids = torch.argmax(gen_logits, dim=-1) + rough_codes_snippet = self.tokenizer.batch_decode(rough_ids, skip_special_tokens=True) + candidates = [] + for i in range(current_bsz): + full_code = rough_codes_snippet[i] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, full_code, mode=reward_mode, problem_data=problem_data, current_logits=gen_logits[i] if reward_mode != "svf" else None, task_type=task_type)) + s += self._analyze_structure(full_code, task_type=task_type) + clean_content = full_code.strip().replace(" ", "").replace("\n", "") + candidates.append({'score': s, 'idx': i, 'key': hash(clean_content[:200] + clean_content[-200:])}) + + stats["pruning_history"].append({"step": step, "scores": [c['score'] for c in candidates]}) + candidates.sort(key=lambda x: x['score'], reverse=True) + + selected_indices, seen_keys = [], set() + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['key'] not in seen_keys: + selected_indices.append(cand['idx']); seen_keys.add(cand['key']) + + if len(selected_indices) < num_parents_to_select: + for cand in candidates: + if len(selected_indices) >= num_parents_to_select: break + if cand['idx'] not in selected_indices: selected_indices.append(cand['idx']) + + top_indices = torch.tensor(selected_indices, device=self.device) + x, conf_scores = self._branch_and_resample(x, conf_scores, top_indices, target_width, mask_id, prompt_length, task_type=task_type) + + current_bsz = target_width + cur_x = x[:current_bsz, :window_end] + next_allowed_pruning_step = step + 1 + pruning_interval + + active_mask = cur_x[:, -block_length:] == mask_id + if active_mask.sum() == 0: break + + stats["nfe"] += current_bsz + + active_logits = self._chunked_forward(cur_x, chunk_size=32, slice_start=-block_length) + + with torch.no_grad(): + if len(stats["entropy_history"]) < 32: + probs_for_stats = torch.softmax(active_logits.float(), dim=-1) + entropy_per_branch = (-(probs_for_stats * torch.log(probs_for_stats + 1e-10)).sum(dim=-1).mean(dim=-1)).cpu().numpy().tolist() + stats["entropy_history"].append(entropy_per_branch) + + x0, x0_p = self._sample_with_temperature(active_logits, temperature, top_k, top_p) + + num_transfer = schedule[step].item() + confidence = torch.where(active_mask, x0_p, -torch.inf) + transfer_idx = torch.zeros_like(x0, dtype=torch.bool) + + for b in range(current_bsz): + k_transfer = min(num_transfer, active_mask[b].sum().item()) + active_indices = torch.where(active_mask[b])[0] + if (confidence[b] > threshold).sum().item() >= k_transfer: + conf_indices = torch.where((confidence[b] > threshold) & active_mask[b])[0]; transfer_idx[b, conf_indices] = True + elif len(active_indices) > 0: + _, topk_indices = torch.topk(confidence[b][active_indices], k=min(k_transfer, len(active_indices))); transfer_idx[b, active_indices[topk_indices]] = True + + if transfer_idx.any(): + cur_x[:, -block_length:][transfer_idx] = x0[transfer_idx] + conf_scores[:current_bsz, window_end-block_length:window_end][transfer_idx] = x0_p[transfer_idx] + + if task_type in ["math", "reasoning"]: + for b in range(current_bsz): + gen_span = cur_x[b, prompt_length:window_end] + text_snippet = self.tokenizer.decode(gen_span, skip_special_tokens=True) + should_stop = False + if task_type == "reasoning" and ("###" in text_snippet): + should_stop = True + if task_type == "math" and ("\\boxed{" in text_snippet and "}" in text_snippet.split("\\boxed{")[-1]): + should_stop = True + + if should_stop: + non_mask_indices = (gen_span != mask_id).nonzero(as_tuple=True)[0] + if len(non_mask_indices) > 0: + last_idx = non_mask_indices[-1].item() + if last_idx + 1 < len(gen_span): + gen_span[last_idx + 1:] = eos_id + cur_x[b, prompt_length:window_end] = gen_span + if window_end < total_length: + x[b, window_end:] = eos_id + conf_scores[b, window_end:] = 1.0 + + for b in range(current_bsz): + gen_window = cur_x[b, prompt_length:window_end] + eos_indices = (gen_window == eos_id).nonzero(as_tuple=True)[0] + if len(eos_indices) > 0: + first_eos_idx = eos_indices[0].item() + if first_eos_idx + 1 < len(gen_window): + gen_window[first_eos_idx + 1:] = eos_id + cur_x[b, prompt_length:window_end] = gen_window + + x = x[:current_bsz] + x[:, :window_end] = cur_x + + stats["nfe"] = int(round(stats["nfe"])) + + final_gen_tokens = x[:current_bsz, prompt_length:] + final_codes = self.tokenizer.batch_decode(final_gen_tokens, skip_special_tokens=True) + final_candidates = [] + + stats["svf_calls"] += len(final_codes) + + for i in range(len(final_codes)): + txt = final_codes[i] + if until: + for term in until: + if term in txt: txt = txt.split(term)[0] + s = self._safe_scalar(self.verifier.get_reward(prompt_text, txt, mode=reward_mode, task_type=task_type)) + s += self._analyze_structure(txt, task_type) + final_candidates.append({'resp': txt, 'score': s}) + + final_candidates.sort(key=lambda x: x['score'], reverse=True) + stats["final_scores"] = [c['score'] for c in final_candidates] + stats["all_trajectories"] = [{"rank": i+1, "resp": c['resp'], "score": c['score']} for i, c in enumerate(final_candidates)] + + return [c['resp'] for c in final_candidates], stats \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/huggingface.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6e1e99e20aeed5b20f7cd2d7a8f9b76155330a --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/huggingface.py @@ -0,0 +1,1489 @@ +import copy +import logging +import os +from datetime import timedelta +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import jinja2 +import torch +import torch.nn.functional as F +import transformers +from accelerate import ( + Accelerator, + InitProcessGroupKwargs, + find_executable_batch_size, +) +from accelerate.utils import get_max_memory +from huggingface_hub import HfApi +from packaging import version +from peft import PeftModel +from peft import __version__ as PEFT_VERSION +from tqdm import tqdm +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, +) + +from dllm_eval import utils +from dllm_eval.api.instance import Instance +from dllm_eval.api.model import TemplateLM +from dllm_eval.api.registry import register_model +from dllm_eval.models.utils import ( + Collator, + clear_torch_cache, + configure_pad_token, + get_dtype, + handle_stop_sequences, + pad_and_concat, + stop_sequences_criteria, +) + + +eval_logger = logging.getLogger(__name__) + + +@register_model("hf-auto", "hf", "huggingface") +class HFLM(TemplateLM): + """ + An abstracted Huggingface model class. Enables usage with both models of + `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. + + Supports data-parallel multi-GPU with HF Accelerate. + """ + + AUTO_MODEL_CLASS = None + _DEFAULT_MAX_LENGTH = 2048 + + def __init__( + self, + pretrained: Union[str, transformers.PreTrainedModel], + backend: Literal["default", "causal", "seq2seq"] = "default", + # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) + revision: Optional[str] = "main", + subfolder: str = "", + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ] = None, + truncation: Optional[bool] = False, + logits_cache: bool = True, + max_length: Optional[int] = None, + device: Optional[str] = "cuda", + dtype: Optional[Union[str, torch.dtype]] = "auto", + softmax_dtype: Optional[Union[str, torch.dtype]] = None, + batch_size: Optional[Union[int, str]] = 1, + max_batch_size: Optional[int] = 64, + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + add_bos_token: Optional[bool] = False, + prefix_token_id: Optional[int] = None, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + parallelize: Optional[bool] = False, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__() + # optionally: take in an already-initialized transformers.PreTrainedModel + if not isinstance(pretrained, str): + eval_logger.warning( + "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." + ) + assert not parallelize, ( + "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" + ) + self._model = pretrained + self._device = self._model.device + self._config = self._model.config + gpus = 0 + + else: + assert isinstance(device, str) + assert isinstance(pretrained, str) + assert isinstance(batch_size, (int, str)) + + gpus = torch.cuda.device_count() + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + if accelerator.num_processes > 1: + self.accelerator = accelerator + + if "npu" in accelerator.device.type: + gpus = torch.npu.device_count() + + # using one process with no model parallelism + if not (parallelize or accelerator.num_processes > 1): + # use user-passed device + device_list = set( + ["cuda", "cpu"] + + [f"cuda:{i}" for i in range(gpus)] + + ["mps", "mps:0"] + + [f"npu:{i}" for i in range(gpus)] + ) + if device and device in device_list: + self._device = torch.device(device) + eval_logger.info(f"Using device '{device}'") + if device in ("mps", "mps:0") and version.parse( + torch.__version__ + ) < version.parse("2.1"): + raise RuntimeError( + f"mps requires torch >= 2.1. You have {torch.__version__}" + ) + else: + eval_logger.info("Device not specified") + eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") + self._device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + else: # Parallelism managed by accelerate + if device != "cuda": + eval_logger.info( + f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." + ) + # TODO: include in warning that `load_in_8bit` etc. affect this too + self._device = ( + self.accelerator.device + if hasattr(self, "accelerator") + else torch.device(device) + ) + + revision = str(revision) # cast to string if not already one + + self._get_config( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + subfolder=subfolder, + ) + + # determine which of 'causal' and 'seq2seq' backends to use for HF models + self._get_backend( + config=self.config, backend=backend, trust_remote_code=trust_remote_code + ) + + # load tokenizer so we know tokenizer vocabulary size before loading model and PEFT + self._create_tokenizer( + pretrained, + tokenizer, + revision=revision, + subfolder=subfolder, + trust_remote_code=trust_remote_code, + use_fast_tokenizer=use_fast_tokenizer, + gguf_file=gguf_file, + add_bos_token=add_bos_token, + ) + + # if we passed `pretrained` as a string, initialize our model now + if isinstance(pretrained, str): + self._create_model( + pretrained=pretrained, + revision=revision, + dtype=dtype, + trust_remote_code=trust_remote_code, + parallelize=parallelize, + gpus=gpus, + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + peft=peft, + delta=delta, + autogptq=autogptq, + gptqmodel=gptqmodel, + gguf_file=gguf_file, + quantization_config=getattr(self.config, "quantization_config", None), + subfolder=subfolder, + **kwargs, + ) + + # access self._model through self.model property outside this method + if isinstance(self.model, torch.nn.Module): + self.model.eval() + self.model.tie_weights() + + self.truncation = truncation + self.logits_cache = logits_cache + self.vocab_size = self.tokenizer.vocab_size + # select (or create) a pad token to use + self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) + + self.add_bos_token = add_bos_token + if "gemma" in getattr(self.config, "model_type", ""): + self.add_bos_token = True + eval_logger.info( + f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." + ) + + self._max_length = max_length + self.pretrained = pretrained + self.delta = delta + self.peft = peft + self.revision = revision + self.batch_schedule = 1 + self.batch_sizes = {} + self.max_batch_size = max_batch_size + self.softmax_dtype = ( + get_dtype(softmax_dtype) if softmax_dtype is not None else None + ) + + if str(batch_size).startswith("auto"): + batch_size = batch_size.split(":") + self.batch_size_per_gpu = batch_size[0] + self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 + else: + self.batch_size_per_gpu = int(batch_size) + + if isinstance(pretrained, str): + if gpus >= 1 or str(self.device) == "mps": + # TODO: can remove this whole snippet except in the mps case, perhaps? + if not (parallelize or autogptq or hasattr(self, "accelerator")): + # place model onto device requested manually, + # if not using HF Accelerate or device_map + # or any other option that preloads model onto device + try: + self.model.to(self.device) + except ValueError: + eval_logger.debug( + "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." + ) + # multigpu data-parallel support when launched with accelerate + if gpus > 1: + if accelerator.num_processes > 1: + if parallelize: + eval_logger.warning( + "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." + ) + elif gpus > accelerator.num_processes: + eval_logger.warning( + "WARNING: The number of total system GPUs does not match the number of spawned processes. " + "If you would like to use data parallelism, please launch the script " + "with 'accelerate launch *script*'. " + f"Current run will proceed with {accelerator.num_processes} devices." + ) + if self.accelerator.is_local_main_process: + eval_logger.info( + f"Using {gpus} devices with data parallelism" + ) + + self._device = torch.device(f"{accelerator.device}") + self.accelerator = accelerator + + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + # if we aren't launching via accelerate, ditch + self._rank = 0 + self._world_size = 1 + else: + # if a PreTrainedModel was passed into HFLM, we forgo distributed setup. + eval_logger.warning( + "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" + ) + self._rank = 0 + self._world_size = 1 + + self.custom_prefix_token_id = prefix_token_id + if prefix_token_id is not None: + eval_logger.info( + f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" + ) + + def _get_accelerate_args( + self, + parallelize: Optional[bool] = None, + device_map: Optional[str] = "auto", + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + gpus: Optional[int] = None, + ) -> dict: + """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" + num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) + num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes + if ( + num_machines == 0 + and hasattr(self, "accelerator") + and self.accelerator is not None + ): + eval_logger.info( + "We are not in a distributed setting for accelerate. Setting model_parallel to False." + ) + parallelize = False + + if parallelize is None: + # If parallelism is unset by the user, we automatically assign model parallelism + # if enough extra GPUs are available + max_memory_all_gpus = get_max_memory() + # We just want gpu, not cpu, max memory + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + parallelize = bool(num_local_processes < len(max_memory_all_gpus)) + eval_logger.info( + f"Setting model parallel to {parallelize} since " + f"the number of local processes is {num_local_processes} " + f"and the number of GPUs is {len(max_memory_all_gpus)}" + ) + + args = {} + if parallelize: # Model parallelism will be used + max_memory = {} + if max_memory_per_gpu is not None: # Using the provided memory requirements + max_memory_per_gpu_map = { + device_idx: max_memory_per_gpu for device_idx in range(gpus) + } + else: # Estimating the possible memory requirements + max_memory_all_gpus = get_max_memory() + if "cpu" in max_memory_all_gpus: + del max_memory_all_gpus["cpu"] + if not hasattr(self, "accelerator"): + max_memory_per_gpu_map = { + k: v for k, v in max_memory_all_gpus.items() + } + else: + # use only 1 / num_processes of the GPUs if we are running under accelerate launch + max_memory_per_gpu_map = { + k: v + for k, v in max_memory_all_gpus.items() + if k % num_local_processes + == (self.accelerator.process_index % num_local_processes) + } + args["max_memory"] = max_memory_per_gpu_map + args["device_map"] = "auto" if device_map is None else device_map + eval_logger.info( + f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}" + ) + + if max_cpu_memory is not None: + max_memory["cpu"] = max_cpu_memory + + args["offload_folder"] = offload_folder + elif ( + device_map is None + ): # No model parallelism, we use the default provided device for our model + if hasattr(self, "accelerator"): + device_map = {"": f"{self.accelerator.device}"} + else: + device_map = {"": str(self.device)} + args["max_memory"] = None + args["device_map"] = device_map + eval_logger.info( + f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}" + ) + else: + args["max_memory"] = None + args["device_map"] = None + eval_logger.info("Model parallel was set to False.") + + return args + + @property + def config(self): + # return the associated transformers.AutoConfig for the given pretrained model. + return self._config + + @property + def model(self): + # returns the model, unwrapping it if using Accelerate + if hasattr(self, "accelerator"): + return self.accelerator.unwrap_model(self._model) + else: + return self._model + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def prefix_token_id(self): + # it is used as prefix for loglikelihood + if self.custom_prefix_token_id is not None: + return self.custom_prefix_token_id + if self.tokenizer.bos_token_id is not None: + return self.tokenizer.bos_token_id + return self.tokenizer.eos_token_id + + @property + def max_length(self): + if self._max_length: # if max length manually set, return it + return self._max_length + seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") + for attr in seqlen_config_attrs: + if hasattr(self.model.config, attr): + return getattr(self.model.config, attr) + if hasattr(self.tokenizer, "model_max_length"): + if self.tokenizer.model_max_length == 1000000000000000019884624838656: + return self._DEFAULT_MAX_LENGTH + return self.tokenizer.model_max_length + return self._DEFAULT_MAX_LENGTH + + @property + def max_gen_toks(self) -> int: + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu + + @property + def device(self): + return self._device + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + @property + def tokenizer_name(self) -> str: + return self.tokenizer.name_or_path.replace("/", "__") + + def _get_backend( + self, + config: Union[transformers.PretrainedConfig, transformers.AutoConfig], + backend: Literal["default", "causal", "seq2seq"] = "default", + trust_remote_code: Optional[bool] = False, + ) -> None: + """ + Helper method during initialization. + Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used. + sets `self.AUTO_MODEL_CLASS` appropriately if not already set. + + **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM, + user must set `self.backend` to be either "causal" or "seq2seq" manually!** + """ + + assert backend in ["default", "causal", "seq2seq"] + + if backend != "default": + # if we've settled on non-default backend, use that manually + if backend == "causal": + self.backend = backend + elif backend == "seq2seq": + self.backend = backend + eval_logger.info( + f"Overrode HF model backend type, and using type '{self.backend}'" + ) + else: + # determine and use the default HF backend for this model, based on its config + metadata. + if ( + getattr(config, "model_type") + in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + ): + # first check if model type is listed under seq2seq models, since some + # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. + # these special cases should be treated as seq2seq models. + self.backend = "seq2seq" + eval_logger.debug(f"Using model type '{self.backend}'") + elif ( + getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + ): + self.backend = "causal" + eval_logger.debug(f"Using model type '{self.backend}'") + else: + if not trust_remote_code: + eval_logger.warning( + "HF model type is neither marked as CausalLM or Seq2SeqLM. \ + This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." + "Setting backend to causal" + ) + # if model type is neither in HF transformers causal or seq2seq model registries + # then we default to assuming AutoModelForCausalLM + self.backend = "causal" + eval_logger.info( + f"Model type cannot be determined. Using default model type '{self.backend}'" + ) + + if self.AUTO_MODEL_CLASS is None: + if self.backend == "causal": + self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM + elif self.backend == "seq2seq": + self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM + + def _get_config( + self, + pretrained: str, + revision: str = "main", + trust_remote_code: bool = False, + gguf_file: Optional[str] = None, + subfolder: str = "", + ) -> None: + """Return the model config for HuggingFace models""" + self._config = transformers.AutoConfig.from_pretrained( + pretrained, + revision=revision, + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + subfolder=subfolder, + ) + + def _create_model( + self, + pretrained: str, + revision: Optional[str] = "main", + dtype: Optional[Union[str, torch.dtype]] = "auto", + trust_remote_code: Optional[bool] = False, + # arguments used for splitting a model across GPUs naively. + # only used if `parallelize=True`. + # (accelerate naive PP (device_map) options) + parallelize: Optional[bool] = False, + gpus: Optional[int] = None, + max_memory_per_gpu: Optional[Union[int, str]] = None, + max_cpu_memory: Optional[Union[int, str]] = None, + offload_folder: Optional[str] = "./offload", + # PEFT, delta weights and quantization options + peft: Optional[str] = None, + delta: Optional[str] = None, + autogptq: Optional[Union[bool, str]] = False, + gptqmodel: Optional[bool] = False, + gguf_file: Optional[str] = None, + quantization_config: Optional[Dict[str, Any]] = None, + subfolder: str = "", + **kwargs, + ) -> None: + """ + Initializes an HF or HF-compatible PreTrainedModel from scratch + inside HFLM, using the kwargs passed into self.__init__(). + + Also handles functionality such as AutoGPTQ usage and PEFT wrapping. + + For future similar extensions to AutoGPTQ that are not core to HF's ecosystem, + (such as PyTorch models that are nearly, but not quite, fully mirroring + HF's public interface relied on in this HFLM class) + please consider subclassing HFLM and overriding this and other methods as needed. + """ + + model_kwargs = kwargs if kwargs else {} + + model_kwargs.update( + self._get_accelerate_args( + parallelize=parallelize, + device_map=kwargs.get("device_map", None), + max_memory_per_gpu=max_memory_per_gpu, + max_cpu_memory=max_cpu_memory, + offload_folder=offload_folder, + gpus=gpus, + ) + ) + + if not autogptq and not gptqmodel: + if model_kwargs.get("load_in_4bit", None): + assert transformers.__version__ >= "4.30.0", ( + "load_in_4bit requires transformers >= 4.30.0" + ) + if transformers.__version__ >= "4.30.0": + if model_kwargs.get("load_in_4bit", None): + if model_kwargs.get("bnb_4bit_compute_dtype", None): + model_kwargs["bnb_4bit_compute_dtype"] = get_dtype( + model_kwargs["bnb_4bit_compute_dtype"] + ) + + self._model = self.AUTO_MODEL_CLASS.from_pretrained( + pretrained, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + gguf_file=gguf_file, + quantization_config=quantization_config, + subfolder=subfolder, + **model_kwargs, + ) + else: + if autogptq and gptqmodel: + raise ValueError( + "Cannot use both 'autogptq' and 'gptqmodel' options at the same time." + ) + + if autogptq: + try: + from auto_gptq import AutoGPTQForCausalLM + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load auto_gptq, but auto-gptq is not installed ", + "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", + ) + + self._model = AutoGPTQForCausalLM.from_quantized( + pretrained, + trust_remote_code=trust_remote_code, + model_basename=None if autogptq is True else Path(autogptq).stem, + use_safetensors=True + if autogptq is True + else autogptq.endswith(".safetensors"), + **model_kwargs, + ) + + if gptqmodel: + try: + from gptqmodel import GPTQModel + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load gptqmodel, but gptqmodel is not installed ", + "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`", + ) + + self._model = GPTQModel.from_quantized( + pretrained, trust_remote_code=trust_remote_code, **model_kwargs + ) + + if peft and delta: + raise ValueError( + "Cannot use both 'peft' and 'delta' options at the same time." + ) + + if peft: + if model_kwargs.get("load_in_4bit", None): + if version.parse(PEFT_VERSION) < version.parse("0.4.0"): + raise AssertionError("load_in_4bit requires peft >= 0.4.0") + if self._model.config.vocab_size != len(self.tokenizer): + # resize model for LoRAs with added tokens + eval_logger.info( + f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..." + ) + self._model.resize_token_embeddings(len(self.tokenizer)) + self._model = PeftModel.from_pretrained( + self._model, peft, revision=revision + ) + elif delta: + if autogptq: + eval_logger.warning( + "Delta weights might trigger unexpected behavior when used with AutoGPTQ." + ) + _model_delta = self.AUTO_MODEL_CLASS.from_pretrained( + delta, + revision=revision, + torch_dtype=get_dtype(dtype), + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + for name, param in self._model.state_dict().items(): + try: + param.data += _model_delta.state_dict()[name] + except KeyError: + raise KeyError(f"Delta model is missing weights for layer: {name}") + except Exception as e: + raise RuntimeError( + f"Failed to add delta weights to layer {name}. Error: {e}" + ) + + del _model_delta + + return None + + def _create_tokenizer( + self, + pretrained: Union[str, transformers.PreTrainedModel], + tokenizer: Optional[ + Union[ + str, + transformers.PreTrainedTokenizer, + transformers.PreTrainedTokenizerFast, + ] + ], + revision: Optional[str] = "main", + trust_remote_code: Optional[bool] = False, + use_fast_tokenizer: Optional[bool] = True, + gguf_file: Optional[str] = None, + add_bos_token: Optional[bool] = False, + subfolder: Optional[str] = "", + ) -> None: + """ + Helper method during initialization. + + Create a tokenizer object corresponding to the correct + tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed. + """ + kwargs = { + "revision": revision, + "trust_remote_code": trust_remote_code, + } + + # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param + if gguf_file is not None: + kwargs["gguf_file"] = gguf_file + else: + kwargs["use_fast"] = use_fast_tokenizer + + if add_bos_token: + kwargs["add_bos_token"] = True + + if subfolder: + kwargs["subfolder"] = subfolder + + if tokenizer: + if isinstance(tokenizer, str): + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + tokenizer, **kwargs + ) + else: + assert isinstance( + tokenizer, transformers.PreTrainedTokenizer + ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast) + self.tokenizer = tokenizer + else: + # Get tokenizer based on 'pretrained' + if isinstance(pretrained, str): + model_name = pretrained + else: + # get the HF hub name via accessor on model + model_name = self.model.name_or_path + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, **kwargs + ) + return None + + def _detect_batch_size(self, requests=None, pos: int = 0): + if requests: + _, context_enc, continuation_enc = requests[pos] + max_length = len( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] + ) + max_context_enc = len(context_enc[-(self.max_length + 1) :]) + max_cont_enc = len(continuation_enc[-(self.max_length + 1) :]) + else: + max_length = self.max_length + max_context_enc = max_length + max_cont_enc = max_length + + # if OOM, then halves batch_size and tries again + @find_executable_batch_size(starting_batch_size=self.max_batch_size) + def forward_batch(batch_size): + if self.backend == "seq2seq": + length = max(max_context_enc, max_cont_enc) + batched_conts = torch.ones( + (batch_size, length), device=self.device + ).long() + test_batch = torch.ones((batch_size, length), device=self.device).long() + call_kwargs = { + "attn_mask": test_batch, + "labels": batched_conts, + } + else: + call_kwargs = {} + test_batch = torch.ones( + (batch_size, max_length), device=self.device + ).long() + for _ in range(5): + out = F.log_softmax( # noqa: F841 + self._model_call(test_batch, **call_kwargs), + dim=-1, + dtype=self.softmax_dtype, + ) + + return batch_size + + try: + batch_size = forward_batch() + except RuntimeError as e: + if "No executable batch size found" in str(e): + batch_size = 1 + else: + raise + + if self.world_size > 1: + # if multi-GPU, always take minimum over all selected batch sizes + max_rnk_bs = torch.tensor([batch_size], device=self.device) + gathered = ( + self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist() + ) + batch_size = min(gathered) + clear_torch_cache() + return batch_size + + clear_torch_cache() + return batch_size + + def tok_encode( + self, string: str, left_truncate_len=None, add_special_tokens=None + ) -> List[int]: + """ """ + # default for None - empty dict, use predefined tokenizer param + # used for all models except for CausalLM or predefined value + special_tokens_kwargs = {} + + # by default for CausalLM - false or self.add_bos_token is set + if add_special_tokens is None: + if self.backend == "causal": + special_tokens_kwargs = { + "add_special_tokens": False or self.add_bos_token + } + # otherwise the method explicitly defines the value + else: + special_tokens_kwargs = {"add_special_tokens": add_special_tokens} + + encoding = self.tokenizer.encode(string, **special_tokens_kwargs) + + # left-truncate the encoded context to be at most `left_truncate_len` tokens long + if left_truncate_len: + encoding = encoding[-left_truncate_len:] + + return encoding + + def tok_batch_encode( + self, + strings: List[str], + padding_side: str = "left", + left_truncate_len: int = None, + truncation: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = padding_side + + add_special_tokens = {} + if self.backend == "causal": + add_special_tokens = {"add_special_tokens": False or self.add_bos_token} + + encoding = self.tokenizer( + strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + **add_special_tokens, + ) + if left_truncate_len: + original_lengths = encoding["input_ids"].size(1) + if original_lengths > left_truncate_len: + eval_logger.warn( + f"Left truncation applied. Original sequence length was {original_lengths}, " + f"truncating to last {left_truncate_len} tokens. Some content will be lost.", + ) + encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] + encoding["attention_mask"] = encoding["attention_mask"][ + :, -left_truncate_len: + ] + self.tokenizer.padding_side = old_padding_side + + return encoding["input_ids"], encoding["attention_mask"] + + def tok_decode(self, tokens, skip_special_tokens=True): + return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + + def _model_call(self, inps, attn_mask=None, labels=None): + """ + :param inps: torch.Tensor + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape + [batch, sequence_ctx]. the size of sequence may vary from call to call + :param attn_mask: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :param labels: torch.Tensor, optional + A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed + (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM + :return + A torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model's decoder + """ + with torch.no_grad(): + if attn_mask is not None or labels is not None: + assert attn_mask is not None and labels is not None + assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM + return self.model( + input_ids=inps, attention_mask=attn_mask, labels=labels + ).logits + else: + assert self.AUTO_MODEL_CLASS in ( + transformers.AutoModelForCausalLM, + transformers.AutoModelForVision2Seq, + ) + return self.model(inps).logits + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + # temperature = 0.0 if not set + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + # build stopping criteria + stopping_criteria = stop_sequences_criteria( + self.tokenizer, stop, context.shape[1], context.shape[0] + ) + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=True, + **generation_kwargs, + ) + + def _select_cont_toks( + self, logits: torch.Tensor, contlen: int = None, inplen: int = None + ) -> torch.Tensor: + if self.backend == "causal": + assert contlen and inplen, ( + "Must pass input len and cont. len to select scored logits for causal LM" + ) + # discard right-padding. + # also discard the input/context tokens. we'll only score continuations. + logits = logits[inplen - contlen : inplen] + elif self.backend == "seq2seq": + assert contlen and not inplen, ( + "Selecting scored logits for Seq2SeqLM requires only cont. len" + ) + # only discard right-padding. + # the logits input to this fn only contain decoder-side tokens. + logits = logits[:contlen] + + return logits + + def loglikelihood_rolling( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[float]: + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + + # First, collect all windows from all requests + all_windows = [] # List of (request_idx, window) tuples + request_window_counts = [] # Track number of windows per request + + for req_idx, (string,) in enumerate( + tqdm( + [req.args for req in requests], + disable=(disable_tqdm or (self.rank != 0)), + ) + ): + rolling_token_windows: List[Tuple[List[int], List[int]]] = list( + map( + utils.make_disjoint_window, + utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.prefix_token_id, + max_seq_len=self.max_length, + context_len=1, + ), + ) + ) + + # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case + windows = [(None,) + x for x in rolling_token_windows] + + # Store windows with their request index + all_windows.extend((req_idx, window) for window in windows) + request_window_counts.append(len(windows)) + + # Handle distributed case padding + pad_amnt = 0 + if self.world_size > 1: + mytensor = torch.tensor(len(all_windows), device=self.device) + gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() + pad_amnt = max(gathered) - gathered[self.rank] + if pad_amnt > 0: + all_windows += pad_amnt * [all_windows[0]] + + all_nlls = [] + batch_size = adaptive_batch_size or self.batch_size + for i in range(0, len(all_windows), batch_size): + batch = all_windows[i : i + batch_size] + # Extract just the windows for processing, keeping track of request indices + batch_indices, batch_windows = zip(*batch) + + batch_nlls = self._loglikelihood_tokens( + requests=batch_windows, + disable_tqdm=False, + override_bs=len(batch_windows), + ) + # Store results with their request indices + all_nlls.extend(zip(batch_indices, batch_nlls)) + + # Remove padding if necessary + if (self.world_size > 1) and (pad_amnt > 0): + all_nlls = all_nlls[:-pad_amnt] + + # Reconstruct per-request loglikelihoods + loglikelihoods = [] + current_idx = 0 + for window_count in request_window_counts: + # Get all nlls for this request + request_nlls = all_nlls[current_idx : current_idx + window_count] + # Sum up the nlls for this request (discarding is_greedy) + request_total = sum(nll[0] for _, nll in request_nlls) + loglikelihoods.append(request_total) + current_idx += window_count + + string = requests[len(loglikelihoods) - 1].args[0] + self.cache_hook.add_partial( + "loglikelihood_rolling", (string,), request_total + ) + + return loglikelihoods + + def _batch_scheduler(self, pos, n_reordered_requests): + sched = pos // int(len(n_reordered_requests) / self.batch_schedule) + if sched in self.batch_sizes: + return self.batch_sizes[sched] + if (len(self.batch_sizes) > 1) and ( + self.batch_sizes[sched - 1] == self.max_batch_size + ): + # if previous batch size is already maximal, skip recomputation + self.batch_sizes[sched] = self.max_batch_size + return self.batch_sizes[sched] + print( + f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size" + ) + self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) + print(f"Determined largest batch size: {self.batch_sizes[sched]}") + return self.batch_sizes[sched] + + def _loglikelihood_tokens( + self, + requests: List[Tuple[Tuple[str, str], List[int], List[int]]], + disable_tqdm: bool = False, + override_bs: int = None, + ) -> List[Tuple[float, bool]]: + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = req[1] + req[2] + return -len(toks), tuple(toks) + + def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): + """Defines the key to group and lookup one-token continuations""" + # Use with group_by="contexts" (optional)" + # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. + # speeds up some multiple-choice tasks proportionally to the number of choices. + # groups requests by context+continuation[:-1] and infer on one request/group. + return req[-2] + req[-1][:-1] + + re_ord = Collator( + requests, + sort_fn=_collate, + group_by="contexts" + if self.backend == "causal" and self.logits_cache + else None, + group_fn=_lookup_one_token_cont, + ) + + # automatic (variable) batch size detection for vectorization + # pull longest context sample from request + n_reordered_requests = len(re_ord) + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else override_bs + if override_bs is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" + and n_reordered_requests > 0 + and not override_bs + else None + ) + + chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running loglikelihood requests", + ) + for chunk in chunks: + inps = [] + cont_toks_list = [] + inplens = [] + + conts = [] + encoder_attns = [] + + padding_len_inp = None + padding_len_cont = None + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + if self.backend == "causal": + total_length = len(context_enc) + len(continuation_enc) + if total_length > self.max_length + 1: + eval_logger.warning( + f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) " + f"exceeds model's maximum length ({self.max_length}). " + f"Truncating {total_length - self.max_length + 1} tokens from the left." + ) + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + elif self.backend == "seq2seq": + inp = torch.tensor( + (context_enc)[-self.max_length :], + dtype=torch.long, + device=self.device, + ) + (inplen,) = inp.shape + + # build encoder attn masks + encoder_attns.append(torch.ones_like(inp)) + + cont = torch.tensor( + (continuation_enc)[-self.max_length :], + # TODO: left-shift these? + # TODO: our code assumes we never end up truncating conts for either model type + dtype=torch.long, + device=self.device, + ) + (contlen,) = cont.shape + + conts.append(cont) + + padding_len_cont = ( + max(padding_len_cont, contlen) + if padding_len_cont is not None + else contlen + ) + + padding_len_inp = ( + max(padding_len_inp, inplen) + if padding_len_inp is not None + else inplen + ) + + inps.append(inp) # [1, inp_length] + cont_toks_list.append(continuation_enc) + inplens.append(inplen) + + # create encoder attn mask and batched conts, if seq2seq + call_kwargs = {} + if self.backend == "causal": + batched_inps = pad_and_concat( + padding_len_inp, inps, padding_side="right" + ) # [batch, padding_len_inp] + elif self.backend == "seq2seq": + # TODO: left-pad encoder inps and mask? + batched_inps = pad_and_concat( + padding_len_inp, inps + ) # [batch, padding_len_inp] + batched_conts = pad_and_concat( + padding_len_cont, conts + ) # [batch, padding_len_cont] + batched_encoder_mask = pad_and_concat( + padding_len_inp, encoder_attns + ) # [batch, padding_len_inp] + call_kwargs = { + "attn_mask": batched_encoder_mask, + "labels": batched_conts, + } + + multi_logits = F.log_softmax( + self._model_call(batched_inps, **call_kwargs), + dim=-1, + dtype=self.softmax_dtype, + ) # [batch, padding_length (inp or cont), vocab] + + for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( + chunk, multi_logits, inplens, cont_toks_list + ): + # Slice to original seq length + contlen = len(cont_toks) + # take only logits in the continuation + # (discard context toks if decoder-only ; discard right-padding) + # also discards + checks for "virtual tokens" in the causal LM's input window + # from prompt/prefix tuning tokens, if applicable + ctx_len = ( + inplen + (logits.shape[0] - padding_len_inp) + if self.backend == "causal" + else None + ) + logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) + logits = logits.unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + + # check for one-token continuation cache hits. + # noop in case group_by != "contexts" or no cache hit and returns the + # original args. Otherwise, expands the logits batch dimension and yields each + # batch along with matching continuation tokens and prompt strings. + # logits -> [1, seq, vocab] + for request_str, cont_toks, logits in re_ord.get_cache( + req_str=request_str, + cxt_toks=ctx_tokens, + cont_toks=cont_toks, + logits=logits, + ): + cont_toks = torch.tensor( + cont_toks, dtype=torch.long, device=self.device + ).unsqueeze(0) # [1, seq] + # Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]). + # i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens + # by choosing key with longest cont if group_by="contexts". + max_equal = ( + greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks + ).all() + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( + -1 + ) # [1, seq] + + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + + res.append(answer) + + if request_str is not None: + # special case: loglikelihood_rolling produces a number of loglikelihood requests + # all with cache key None. instead do add_partial on the per-example level + # in the loglikelihood_rolling() function for those. + self.cache_hook.add_partial( + "loglikelihood", request_str, answer + ) + pbar.update(1) + + pbar.close() + + return re_ord.get_original(res) + + def generate_until( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[str]: + res = [] + + def _collate(req: Tuple[str, dict]): + """Defines the key for the sorted method""" + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + toks = self.tok_encode(req[0]) + return -len(toks), req[0] + + pbar = tqdm( + total=len(requests), + disable=(disable_tqdm or (self.rank != 0)), + desc="Running generate_until requests", + ) + adaptive_batch_size = None + if self.batch_size == "auto": + # using rolling window with maximum context + print("Passed argument batch_size = auto. Detecting largest batch size") + batch_size = self._detect_batch_size() + print(f"Determined Largest batch size: {batch_size}") + adaptive_batch_size = batch_size + # for each different set of kwargs, we execute all requests, by batch. + batch_size = ( + self.batch_size + if self.batch_size != "auto" + else adaptive_batch_size + if adaptive_batch_size is not None + else 0 + ) + batch_fn = ( + self._batch_scheduler + if self.batch_size == "auto" and not adaptive_batch_size + else None + ) + + # we group requests by their generation_kwargs, + # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling + # in the same batch. + # group_fn=lambda x: x[1] -> x=(context, gen_kwargs) + re_ords = Collator( + [reg.args for reg in requests], + sort_fn=_collate, + group_by="gen_kwargs", + group_fn=lambda x: x[1], + ) + chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) + eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) + for chunk in chunks: + contexts, all_gen_kwargs = zip(*chunk) + # we assume all gen kwargs in the batch are the same + # this is safe to assume because the `grouper` object ensures it. + gen_kwargs = all_gen_kwargs[0] + # unpack our keyword arguments. + if isinstance(gen_kwargs, dict): + kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 + # add EOS token to stop sequences + until = handle_stop_sequences(kwargs.pop("until", None), eos=eos) + else: + raise ValueError( + f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" + ) + if "max_gen_toks" in kwargs.keys(): + max_gen_toks = kwargs.pop("max_gen_toks") + else: + max_gen_toks = self.max_gen_toks + + # set the max length in tokens of inputs ("context_enc") + if self.backend == "causal": + # max len for inputs = max length, minus room to generate the max new tokens + max_ctx_len = self.max_length - max_gen_toks + assert max_ctx_len > 0, ( + f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})." + ) + elif self.backend == "seq2seq": + # max len for inputs = encoder's whole max_length + max_ctx_len = self.max_length + + # encode, pad, and truncate contexts for this batch + context_enc, attn_masks = self.tok_batch_encode( + contexts, + left_truncate_len=max_ctx_len, + truncation=self.truncation, + ) + context_enc = context_enc.to(self.device) + attn_masks = attn_masks.to(self.device) + + if "max_length" not in kwargs: + kwargs["max_length"] = context_enc.shape[1] + max_gen_toks + + # perform batched generation + cont = self._model_generate( + context=context_enc, + attention_mask=attn_masks, + stop=until, + **kwargs, + ) + + cont_toks_list = cont.tolist() + for cont_toks, context in zip(cont_toks_list, contexts): + # discard context + left-padding toks if using causal decoder-only LM + if self.backend == "causal": + cont_toks = cont_toks[context_enc.shape[1] :] + + s = self.tok_decode(cont_toks) + + # use secondary stop seqs to cut off should-have-been-stopped content post-hoc + for term in until: + if len(term) > 0: + # ignore '' separator, + # for seq2seq case where self.tok_decode(self.eot_token_id) = '' + s = s.split(term)[0] + + res.append(s) + + self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) + pbar.update(1) + # reorder this group of results back to original unsorted form + res = re_ords.get_original(res) + + pbar.close() + + return res + + def apply_chat_template( + self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True + ) -> str: + """ + Method to apply a chat template to a list of chat history between user and model. + """ + try: + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + except jinja2.exceptions.TemplateError: + eval_logger.warning( + "Failed to apply chat template. removing the system role in chat history." + ) + chat_history = [msg for msg in chat_history if msg["role"] != "system"] + chat_templated = self.tokenizer.apply_chat_template( + chat_history, + tokenize=False, + add_generation_prompt=add_generation_prompt, + continue_final_message=not add_generation_prompt, + ) + + return chat_templated + + def get_model_info(self) -> dict: + """ + Method to get Hugging Face model information for experiment reproducibility. + """ + + def get_model_num_params(model) -> int: + if hasattr(model, "num_parameters"): + return model.num_parameters() + if hasattr(model, "parameters"): + return sum(p.numel() for p in model.parameters()) + else: + return -1 + + def get_model_dtype(model) -> str: + if hasattr(model, "dtype"): + return model.dtype + else: + return "" + + def get_model_sha(pretrained: str, revision: str) -> str: + try: + model_info = HfApi().model_info(repo_id=pretrained, revision=revision) + return model_info.sha + except Exception as e: + eval_logger.debug( + f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}" + ) + return "" + + model_info = { + "model_num_parameters": get_model_num_params(self._model), + "model_dtype": get_model_dtype(self._model), + "model_revision": self.revision, + "model_sha": get_model_sha(self.pretrained, self.revision), + } + if self.peft: + model_info["peft_sha"] = get_model_sha(self.peft, self.revision) + if self.delta: + model_info["delta_sha"] = get_model_sha(self.delta, self.revision) + return model_info diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e17fa224b22fbbef442c94e13d4f7c237d3c647d --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/utils.py @@ -0,0 +1,854 @@ +import collections +import fnmatch +import gc +import itertools +import logging +import time +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) + +import torch +import transformers + + +eval_logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from PIL import Image + from transformers import PreTrainedTokenizerBase + from transformers.configuration_utils import PretrainedConfig + + +def chunks(iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + for i, x in enumerate(iter): + arr.append(x) + if len(arr) == (fn(i, iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +class MultiChoice: + def __init__(self, choices) -> None: + self.choices = choices + + # Simple wildcard support (linux filename patterns) + def __contains__(self, values) -> bool: + for value in values.split(","): + if len(fnmatch.filter(self.choices, value)) == 0: + eval_logger.info("Available tasks to choose:") + for choice in self.choices: + eval_logger.info(f" - {choice}") + raise ValueError("'{}' is not in task list".format(value)) + return True + + def __iter__(self) -> Iterator: + for choice in self.choices: + yield choice + + +class Grouper: + """ + takes an array `arr` and function `fn` and returns a dictionary + with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all + objects in `arr` satisfying `key == fn(ob)`. + """ + + def __init__(self, arr, fn) -> None: + # self.orig_arr = arr + self.size = len(arr) + arr = list(enumerate(arr)) + + def group_return_dict(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + return res + + arr = group_return_dict(arr, lambda x: fn(x[1])) + + # self.arr has format Dict[Tuple[int, ]] + self.arr = arr + self._grouped = None + + def get_grouped(self): + # return the contents but not indices for our grouped dict. + if self._grouped: + return self._grouped + grouped = {} + for key in self.arr.keys(): + # drop the index from each element of self.arr + grouped[key] = [y[1] for y in self.arr[key]] + self._grouped = grouped + return grouped + + def get_original(self, grouped_dict): + # take in a grouped dictionary with e.g. results for each key listed + # in the same order as the instances in `self.arr`, and + # return the results in the same (single list) order as `self.orig_arr`. + res = [None] * self.size + cov = [False] * self.size + # orig = [None] * self.size + + assert grouped_dict.keys() == self.arr.keys() + + for key in grouped_dict.keys(): + for (ind, _), v in zip(self.arr[key], grouped_dict[key]): + res[ind] = v + cov[ind] = True + # orig[ind] = _ + + assert all(cov) + # assert orig == self.orig_arr + + return res + + +def pad_and_concat( + max_length: int, + tensors: List[torch.Tensor], + padding_side: Literal["right", "left"] = "right", +): + """ + Method for padding a list of tensors given the maximum tensor + length in the batch. Used for batching inputs and continuations in + seq2seq models. + """ + assert padding_side == "left" or padding_side == "right", ( + f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" + ) + + for i, tensor in enumerate(tensors): + if len(tensor.shape) == 2: + tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size + tensor_len = tensor.shape[0] + if tensor_len < max_length: + if padding_side == "right": + # right-pad + tensors[i] = torch.cat( + [ + tensor, # [seq] + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + ], + dim=0, + ).unsqueeze(0) + else: + # left-pad + tensors[i] = torch.cat( + [ + torch.zeros( + max_length - tensor_len, + dtype=torch.long, + device=tensor.device, + ), # [padding_length - seq] + tensor, # [seq] + ], + dim=0, + ).unsqueeze(0) + else: + tensors[i] = tensor.unsqueeze(0) + + return torch.cat(tensors, dim=0) + + +def clear_torch_cache() -> None: + gc.collect() + torch.cuda.empty_cache() + + +def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: + """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig""" + if isinstance(dtype, str) and dtype != "auto": + # Convert `str` args torch dtype: `float16` -> `torch.float16` + _torch_dtype = getattr(torch, dtype) + else: + _torch_dtype = dtype + return _torch_dtype + + +class MultiTokenEOSCriteria(transformers.StoppingCriteria): + """Criteria to stop on the specified multi-token sequence.""" + + def __init__( + self, + sequence: str, + tokenizer: transformers.PreTrainedTokenizer, + initial_decoder_input_length: int, + batch_size: int, + ) -> None: + self.initial_decoder_input_length = initial_decoder_input_length + self.done_tracker = [False] * batch_size + self.sequence = sequence + self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False) + # print(sequence, self.sequence_ids) + # we look back for 2 more tokens than it takes to encode our stop sequence + # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']` + # and we don't want to mistakenly not stop a generation because our + # (string) stop sequence was output in a different tokenization + + # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model, + # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized + # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described. + self.sequence_id_len = len(self.sequence_ids) + 2 + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs) -> bool: + # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence + lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :] + + lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :] + + lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) + + for i, done in enumerate(self.done_tracker): + if not done: + self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] + return False not in self.done_tracker + + +def stop_sequences_criteria( + tokenizer: transformers.PreTrainedTokenizer, + stop_sequences: List[str], + initial_decoder_input_length: int, + batch_size: int, +) -> transformers.StoppingCriteriaList: + return transformers.StoppingCriteriaList( + [ + *[ + MultiTokenEOSCriteria( + sequence, tokenizer, initial_decoder_input_length, batch_size + ) + for sequence in stop_sequences + ], + ] + ) + + +def undistribute(iterable): + """ + Undoes https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute . + + Re-interleaves results that have been split using more_itertools.distribute: + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + >>> undistribute([group_1, group_2]) + [1, 2, 3, 4, 5, 6] + + Handles non-uniform component lengths: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + >>> undistribute(children) + [1, 2, 3, 4, 5, 6, 7] + + Also handles when some iterables are empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + >>> undistribute(children) + [1, 2, 3] + + """ + + return [ + x + for x in itertools.chain.from_iterable( + itertools.zip_longest(*[list(x) for x in iterable]) + ) + if x is not None + ] + + +def retry_on_specific_exceptions( + on_exceptions: List[Type[Exception]], + max_retries: Optional[int] = None, + backoff_time: float = 3.0, + backoff_multiplier: float = 1.5, + on_exception_callback: Optional[Callable[[Exception, float], Any]] = None, +): + """Retry on an LLM Provider's rate limit error with exponential backoff + For example, to use for OpenAI, do the following: + ``` + from openai import RateLimitError + + # Recommend specifying max_retries to avoid infinite loops! + @retry_on_specific_exceptions([RateLimitError], max_retries=3) + def completion(...): + # Wrap OpenAI completion function here + ... + ``` + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + sleep_time = backoff_time + attempt = 0 + while max_retries is None or attempt < max_retries: + try: + return func(*args, **kwargs) + except tuple(on_exceptions) as e: + if on_exception_callback is not None: + on_exception_callback(e, sleep_time) + time.sleep(sleep_time) + sleep_time *= backoff_multiplier + attempt += 1 + + return wrapper + + return decorator + + +class Collator: + """ + A class for reordering and batching elements of an array. + + This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data. + + Objects of this class have the group_by attribute which determines the method for grouping + the data while batching it. Three options include "gen_kwargs", "contexts", or None: + If group_by == "gen_kwargs" then requests will be grouped by gen_kwargs + If group_by == "contexts" then requests will be grouped by context + cont[:-1] + If None then requests will just be reordered by length descending. + """ + + def __init__( + self, + arr: List, + sort_fn: Callable = lambda x: x, + group_fn: Callable = lambda x: x[1], + group_by: Union[Literal["gen_kwargs", "contexts"], None] = None, + ) -> None: + self._group_by = group_by + # 0 indices are enumerated indices. Apply functions to original arr. + self._sort_fn = lambda x: sort_fn(x[1]) + self._group_fn = lambda x: group_fn(x[1]) + self._reorder_indices: List = [] + self._size = len(arr) + self._arr_with_indices: Union[Dict, Tuple[Tuple[int, Any], ...]] = tuple( + enumerate(arr) + ) # [indices, (arr)] + if self._group_by == "contexts": + self._group_by_context() + elif self._group_by == "gen_kwargs": + self._group_by_index() + + def _group_by_index(self) -> None: + """Group the elements of a list based on their indices.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="gen_kwargs" + ) + + def _group_by_context(self) -> None: + """Group the array with indices by context.""" + self._arr_with_indices = self.group( + self._arr_with_indices, fn=self._group_fn, group_by="contexts" + ) + + def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator: + """ + Generates and yields batches from the reordered array. The method of grouping and batching + depends on the parameter `group_by`. + If `group_by` is set to "gen_kwargs", it will batch the + re-ordered values with same gen_kwargs for each batch. + If `group_by` is "contexts", it caches the requests by context before batching. + If `group_by` is neither "gen_kwargs" nor "contexts", it yields the reordered array + + Parameters: + - n (int): The size of each batch. Defaults to 1. + - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of + each batch. Optional, defaults to None. + + Returns: + Iterator: An iterator over batches of reordered elements grouped as per the `group_by` + attribute. + + Yields: + List of batched elements according to the `group_by` attribute. + """ + if self._group_by == "gen_kwargs": + for ( + key, + values, + ) in self._arr_with_indices.items(): # type: ignore + values = self._reorder(values) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + elif self._group_by == "contexts": + # Get one sample from each key. + # Select longest continuation per group to ensure sufficient context logits + values = self._reorder( + [ + max(value, key=lambda x: len(x[1][-1])) + for value in self._arr_with_indices.values() + ] + ) + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + else: + values = self._reorder(self._arr_with_indices) # type: ignore + batch = self.get_chunks(values, n=n, fn=batch_fn) + yield from batch + + def get_cache( + self, + req_str: Tuple[str, str] = None, + cxt_toks: List[int] = None, + cont_toks: List[int] = None, + logits: torch.Tensor = None, + ) -> Iterator[Tuple[Tuple[str, str], List[int], torch.Tensor]]: + """ + Retrieves cached single-token continuations and their associated arguments, updating indices as necessary. + + The behavior of this function varies depending on how the `group_by` attribute is set: + + - When `group_by` is "contexts": + The function identifies single-token continuations by checking for keys that equate to + [context+continuation][-1] and logs the indices for re-ordering. + In this mode, this function can work in two scenarios: + + 1. Cache Hit - Single Match: + If a single matching context-continuation pair is found in the cache, + the function yields the original arguments. + + 2. Cache Hit - Multiple Matches: + If multiple matching context-continuation pairs are found in the cache, + the function expands the logits batch dimension to match the number of cache hits. + It updates the original requests and continuation tokens. + + - When `group_by` is not set to "contexts": + This method yields the original arguments, logits and continuation tokens, + without checking for one-token continuations. + + Parameters: + - req_str (tuple[str, str]): Original strings used for CachingLM. + - cxt_toks (list[int]): Full context tokens used for lookup. + - cont_toks (list[int]): Continuation tokens for which logits were generated. + - logits (torch.Tensor [1, seq_length, vocab_size]): Logits generated by the model given context and continuation keys. + + Yields: + - Iterator: + - req_str (tuple[str, str]): strings used for CachingLM. + - cont_toks (list[int]) : continuation tokens. + - logits (torch.Tensor [1, seq_length, vocab_size]): The original logits (repeated cache hit times) + """ + if self._group_by == "contexts": + cache_hit: List[ + Tuple[int, Tuple[Tuple[str, str], List[int], List[int]]] + ] = self._arr_with_indices.pop(tuple(cxt_toks + cont_toks[:-1])) + if (cache_size := len(cache_hit)) == 1: + self._reorder_indices.extend(x[0] for x in cache_hit) + yield req_str, cont_toks, logits + else: + # If we have matching requests then expand the batch dimension (no-op) and + # yield each along with its corresponding args. + multilogits = logits.expand(cache_size, -1, -1).chunk(cache_size) + indices, req_str, cont_toks = zip( + *[(x[0], x[1][0], x[-1][-1]) for x in cache_hit] + ) + self._reorder_indices.extend(indices) + for c_key, cont_tok, logit in zip(req_str, cont_toks, multilogits): + yield c_key, cont_tok, logit + else: + yield req_str, cont_toks, logits + + def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator: + """ + Reorders the elements in the array based on the sorting function. + + Parameters: + - arr (list | tuple[tuple[int, Any], ...]]): The array or iterable to be reordered. + + Yields: + Iterator + """ + arr = sorted(arr, key=self._sort_fn) + if not self._group_by == "contexts": + # If grouped by contexts then indices will be set in get_cache() + self._reorder_indices.extend([x[0] for x in arr]) + yield from [x[1] for x in arr] + + def get_original(self, newarr: List) -> List: + """ + Restores the original order of elements from the reordered list. + + Parameters: + - newarr (list): The reordered array. + + Returns: + list: The array with elements restored to their original order. + """ + res = [None] * self._size + cov = [False] * self._size + + for ind, v in zip(self._reorder_indices, newarr): + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + def __len__(self): + return self._size + + @staticmethod + def group( + arr: Iterable, + fn: Callable, + group_by: Literal["gen_kwargs", "contexts"] = "gen_kwargs", + ) -> dict: + """ + Groups elements of an iterable based on a provided function. + + + The `group_by` parameter determines the method of grouping. + If `group_by` is "contexts", the elements are grouped by [context + cont][:-1]. + If `group_by` is "gen_kwargs", the elements are grouped based on the gen_kwargs dict. + + Parameters: + - arr (Iterable): The iterable to be grouped. + - fn (Callable): The function to determine the grouping. + - values (bool): If True, returns the values of the group. Defaults to False. + + Returns: + Iterator: An iterable of grouped elements. + """ + res = collections.defaultdict(list) + for ob in arr: + # where ob == [context + cont] + if group_by == "contexts": + res[tuple(fn(ob))].append(ob) + else: + try: + hashable_dict = tuple( + ( + key, + tuple(value) + if isinstance(value, collections.abc.Iterable) + else value, + ) + for key, value in sorted(fn(ob).items()) + ) + res[hashable_dict].append(ob) + except (TypeError, AttributeError): + res[tuple(fn(ob))].append(ob) + return res + + @staticmethod + def get_chunks(_iter, n: int = 0, fn=None): + """ + Divides an iterable into chunks of specified size or based on a given function. + Useful for batching + + Parameters: + - iter: The input iterable to be divided into chunks. + - n: An integer representing the size of each chunk. Default is 0. + - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None. + + Returns: + An iterator that yields chunks of the input iterable. + + Example usage: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for chunk in chunks(data, 3): + print(chunk) + ``` + Output: + ``` + [1, 2, 3] + [4, 5, 6] + [7, 8, 9] + [10] + ``` + """ + arr = [] + _iter = tuple(_iter) + for i, x in enumerate(_iter): + arr.append(x) + if len(arr) == (fn(i, _iter) if fn else n): + yield arr + arr = [] + + if arr: + yield arr + + +def configure_pad_token( + tokenizer: "PreTrainedTokenizerBase", + model_config: Optional["PretrainedConfig"] = None, +) -> "PreTrainedTokenizerBase": + """ + This function checks if the (Hugging Face) tokenizer has a padding token and sets it if not present. + Some tokenizers require special handling. + + Args: + tokenizer: The tokenizer for which the padding token is to be handled. + model_config: The configuration of the model. Default is None. + + Returns: + The tokenizer after the padding token has been handled. + + Raises: + AssertionError: If the tokenizer is of type RWKVWorldTokenizer or Rwkv5Tokenizer and the padding token id is not 0. + """ + if tokenizer.pad_token: + pass + elif tokenizer.unk_token: + tokenizer.pad_token_id = tokenizer.unk_token_id + elif tokenizer.eos_token: + tokenizer.pad_token_id = tokenizer.eos_token_id + else: + # handle special cases + if model_config and getattr(model_config, "model_type", None) == "qwen": + # Qwen's trust_remote_code tokenizer does not allow for adding special tokens + tokenizer.pad_token = "<|endoftext|>" + elif ( + tokenizer.__class__.__name__ == "RWKVWorldTokenizer" + or tokenizer.__class__.__name__ == "Rwkv5Tokenizer" + ): + # The RWKV world tokenizer, does not allow for adding special tokens / setting the pad token (which is set as 0) + # The additional tokenizer name check is needed, as there exists rwkv4 models with neox tokenizer + # --- + # Note that the world tokenizer class name, might change in the future for the final huggingface merge + # https://github.com/huggingface/transformers/pull/26963 + assert tokenizer.pad_token_id == 0 + else: + tokenizer.add_special_tokens({"pad_token": "<|pad|>"}) + + return tokenizer + + +def replace_placeholders( + string: str, default_placeholder: str, image_token: str, max_images: int +): + """ + A utility function used for local multimodal models. It locates all `placeholder` string + occurrences in the given input `string_` and replaces the first `max_count` instances with + `replacement`, and all subsequent occurrences with the empty string. + + This is used to replace placeholder tags by model-specific image tokens like <|image_pad|> + and to allow for only the first `max_count` images to be passed to a model if desired. + + :param string: The original string containing placeholders. + :param default_placeholder: The placeholder text to be replaced. + :param image_token: The token to replace the placeholder with. + :param max_images: The maximum number of replacements to make. + :return: The string with placeholders replaced. + """ + count = 0 + result = [] + + parts = string.split(default_placeholder) + for part in parts[:-1]: # Iterate through all but the last part + result.append(part) + if count < max_images: + result.append(image_token) + count += 1 + elif default_placeholder != image_token: + result.append(default_placeholder) + + # Add the last part of the string + result.append(parts[-1]) + return "".join(result) + + +def flatten_image_list(images: List[List]): + """ + Takes in a list of lists of images, and returns a single list of all images in order. + Used for some multimodal models like Llava-1.5 which expects this flattened-list format for its image processor. + + :param images: A list of lists of PIL images. + :return: a list of PIL images, via concatenating all the sub-lists in order. + """ + return [image for image_list in images for image in image_list] + + +def handle_stop_sequences( + until: Union[str, List[str], None], eos: Optional[str] +) -> List[str]: + """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token.""" + if isinstance(until, str): + until = [until] + elif until is None: + until = [] + elif not isinstance(until, list): + raise ValueError( + f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" + ) + + if eos is not None and eos not in until: + until.append(eos) + return until + + +def resize_image( + image: "Image.Image", + width: Optional[int] = None, + height: Optional[int] = None, + max_dimension: Optional[int] = None, + keep_aspect_ratio: bool = True, + resample_filter: Union[int, str] = "Image.BICUBIC", + min_width: int = 1, + min_height: int = 1, +) -> "Image.Image": + """ + Resizes a PIL Image object with flexible options. + + Args: + image: The PIL Image object to resize. + width: Target width in pixels. + height: Target height in pixels. + max_dimension: Maximum size for the longer dimension of the image. + keep_aspect_ratio: If True (default) and both width and height are provided, + the image is resized to fit within these dimensions while + maintaining its aspect ratio. If False, the image is stretched + to the exact width and height. + resample_filter: The resampling filter to use for resizing. + Defaults to Image.BICUBIC. + min_width: Minimum width for the resized image. Defaults to 1. + min_height: Minimum height for the resized image. Defaults to 1. + + Returns: + The resized PIL Image object. If no resize parameters are provided + or if the image already meets the criteria, the original image is returned. + + Order of precedence for resizing: + 1. If width AND height are provided: + - If keep_aspect_ratio is True: Fits image within bounds, preserving aspect ratio. + - If keep_aspect_ratio is False: Resizes to exact dimensions (may distort). + 2. Else if only width is provided: Calculates height proportionally. + 3. Else if only height is provided: Calculates width proportionally. + 4. Else if max_dimension is provided: Resizes the longest side to max_dimension + and scales the other side proportionally. + 5. If none of the above are provided, returns the original image. + """ + original_width, original_height = image.size + + # If no arguments are provided, return the original image + if width is None and height is None and max_dimension is None: + return image + + new_width = original_width + new_height = original_height + + if width is not None and height is not None: + # No resize needed if image is already smaller than target dimensions + if original_width <= width and original_height <= height: + return image + + if keep_aspect_ratio: + # Calculate the ratio to fit within the target dimensions + ratio = min(width / original_width, height / original_height) + new_width = int(original_width * ratio) + new_height = int(original_height * ratio) + else: + # Stretch to exact dimensions + new_width = width + new_height = height + elif width is not None: + # No resize needed if width is already smaller + if original_width <= width: + return image + # Calculate height proportionally + new_width = width + new_height = int((original_height / original_width) * new_width) + elif height is not None: + # No resize needed if height is already smaller + if original_height <= height: + return image + # Calculate width proportionally + new_height = height + new_width = int((original_width / original_height) * new_height) + elif max_dimension is not None: + # No resize needed if both dimensions are smaller than max_dimension + if max(original_height, original_width) <= max_dimension: + return image + + if original_width > original_height: + # Width is the longer side + new_width = max_dimension + new_height = int((original_height / original_width) * new_width) + else: + # Height is the longer side or sides are equal + new_height = max_dimension + new_width = int((original_width / original_height) * new_height) + + # Ensure dimensions are at least minimum values + new_width = max(min_width, new_width) + new_height = max(min_height, new_height) + + # Perform the resize operation with the calculated dimensions + return image.resize((new_width, new_height), resample_filter) + + +def truncate_tokens( + tokens: List[int], + max_length: int, + tokenizer: "PreTrainedTokenizerBase", + strategy: str = "left", +): + if strategy == "left": + return tokens[-max_length:] + elif strategy == "right": + return tokens[:max_length] + elif strategy == "middle": + # Truncate the middle of the sequence + left_length = max_length // 2 + right_length = max_length - left_length + return tokens[:left_length] + tokens[-right_length:] + return None diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/verifier.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..28f7ec6cc78d40d8df7e8d894d8d8d83222bffa7 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/models/verifier.py @@ -0,0 +1,147 @@ +import torch +import logging +import ast +import re +import numpy as np +import textwrap + +logger = logging.getLogger(__name__) + +class CodeVerifier: + def __init__(self, model, tokenizer, device="cuda"): + self.model = model + self.tokenizer = tokenizer + self.device = device + + self.yes_ids, self.no_ids = [], [] + for t in ["Yes", " Yes"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) == 1: self.yes_ids.append(ids[0]) + for t in ["No", " No"]: + ids = self.tokenizer.encode(t, add_special_tokens=False) + if len(ids) == 1: self.no_ids.append(ids[0]) + + def _extract_python_code(self, text): + text = text.strip() + match = re.search(r"```python\s*(.*?)```", text, re.DOTALL) + if match: return match.group(1) + match_generic = re.search(r"```\s*(.*?)```", text, re.DOTALL) + if match_generic: return match_generic.group(1) + return text + + def check_syntax(self, code_str): + clean_code = self._extract_python_code(code_str) + try: + if len(clean_code.strip()) < 5: return False + ast.parse(clean_code) + return True + except: + return False + + def compute_confidence(self, logits): + if logits is None: return 0.0 + probs = torch.softmax(logits, dim=-1) + max_probs, _ = torch.max(probs, dim=-1) + log_probs = torch.log(max_probs + 1e-10) + return torch.exp(torch.mean(log_probs)).item() + + def svf_score(self, prompt, code_str, task_type="code"): + + max_len = 2000 + if len(code_str) > max_len: + if task_type == "reasoning": + truncated_code = code_str[:500] + "\n...[truncated]...\n" + code_str[-(max_len-500):] + else: + truncated_code = code_str[-max_len:] + else: + truncated_code = code_str + + if task_type == "code": + prompt_template = f""" + You are an expert programming contest judge. Your task is to evaluate a generated solution for a given problem based on correctness, efficiency, and adherence to constraints. + + [Problem Statement] + {prompt} + [/Problem Statement] + + [Proposed Python Solution] + ```python + {truncated_code} + ``` + [/Proposed Python Solution] + + **Analysis Steps:** + 1. Correctness: Does the core algorithm correctly solve the problem? + 2. Efficiency: Is the time complexity acceptable for the given constraints? + 3. Edge Cases & Constraints: Does the code handle all rules and edge cases? + + **Conclusion**: Based on your analysis, is the solution likely to be fully correct? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "math": + prompt_template = f""" + You are an expert mathematician and competition judge. Your task is to evaluate a proposed mathematical solution for a given problem based on its logical rigor and accuracy. + + [Math Problem] + {prompt} + [/Math Problem] + + [Proposed Mathematical Solution] + {truncated_code} + [/Proposed Mathematical Solution] + + **Analysis Steps:** + 1. Reasoning Validity: Are the logical steps and mathematical properties applied correctly? + 2. Calculation Accuracy: Are the intermediate calculations or algebraic manipulations accurate? + 3. Goal Alignment: Does the current reasoning path directly lead toward the final answer required by the problem? + + **Conclusion**: Based on your analysis, is this solution path sound and likely to result in the correct final answer? Answer with a single word: Yes or No. + **Answer:** """ + + elif task_type == "reasoning": + prompt_template = f""" + You are an expert reading comprehension and faithfulness judge. Your task is to evaluate a generated answer based on the provided context and question. + + [Context and Question] + {prompt} + [/Context and Question] + + [Proposed Answer] + {truncated_code} + [/Proposed Answer] + + **Analysis Steps :** + 1. Faithfulness: Is the answer an exact, literal span from the context? + 2. Relevance: Does the answer directly address the specific question asked without hallucinating external information? + 3. Accuracy: Does the provided context strictly support this answer? + + **Conclusion**: Based on your analysis, is the answer fully faithful to the context and correct? Answer with a single word: Yes or No. + **Answer:** """ + + else: + prompt_template = f"Is the following answer correct?\nQuestion: {prompt}\nAnswer: {truncated_code}\nAnswer Yes or No.\nAnswer:" + + verify_text = textwrap.dedent(prompt_template).strip() + input_ids = self.tokenizer(verify_text, return_tensors="pt").input_ids.to(self.device) + + if input_ids.shape[1] > self.model.config.max_position_embeddings - 16: + logger.warning("Verifier input is too long, truncating from the left.") + input_ids = input_ids[:, - (self.model.config.max_position_embeddings - 16):] + + with torch.no_grad(): + outputs = self.model(input_ids) + logits = outputs.logits[0, -1, :] + + yes_score = max((logits[i].item() for i in self.yes_ids if i < logits.shape[-1]), default=-float('inf')) + no_score = max((logits[i].item() for i in self.no_ids if i < logits.shape[-1]), default=-float('inf')) + + if yes_score == -float('inf') and no_score == -float('inf'): return 0.5 + + probs = torch.softmax(torch.tensor([yes_score, no_score]), dim=0) + return probs[0].item() + + def get_reward(self, prompt, code_str, mode="confidence", problem_data=None, current_logits=None, task_type="code"): + if mode == "svf": + return self.svf_score(prompt, code_str, task_type=task_type) + else: + return self.compute_confidence(current_logits) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/prompts/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c2ce897dcde522ac82d0cbe0e06db1e02b1b72 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/prompts/__init__.py @@ -0,0 +1,128 @@ +import ast +import logging +import os +from typing import Dict + +from dllm_eval import utils + + +eval_logger = logging.getLogger(__name__) + +# Prompt library. +# Stores prompts in a dictionary indexed by 2 levels: +# prompt category name, and prompt name. +# This allows us to access prompts +PROMPT_REGISTRY: Dict[str, Dict[str, str]] = { + "qa-basic": { + "question-newline-answer": "Question: {{question}}\nAnswer:", + "q-newline-a": "Q: {{question}}\nA:", + }, +} + + +def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None): + # unpack prompt name + category_name, prompt_name = prompt_id.split(":") + if subset_name is None: + dataset_full_name = dataset_name + else: + dataset_full_name = f"{dataset_name}-{subset_name}" + eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}") + if category_name == "promptsource": + try: + from promptsource.templates import DatasetTemplates + except ModuleNotFoundError as exception: + raise type(exception)( + "Tried to load a Promptsource template, but promptsource is not installed ", + "please install promptsource via pip install lm-eval[promptsource] or pip install -e .[promptsource]", + ) + try: + if subset_name is None: + prompts = DatasetTemplates(dataset_name=dataset_name) + else: + prompts = DatasetTemplates( + dataset_name=dataset_name, subset_name=subset_name + ) + except Exception: + raise ValueError(f"{dataset_name} and {subset_name} not found") + if prompt_name in prompts.all_template_names: + return prompts[prompt_name] + else: + raise ValueError( + f"{prompt_name} not in prompt list {prompts.all_template_names}" + ) + elif ".yaml" in category_name: + import yaml + + with open(category_name, "rb") as file: + prompt_yaml_file = yaml.full_load(file) + + prompt_string = prompt_yaml_file["prompts"][prompt_name] + return PromptString(prompt_string) + else: + try: + return PROMPT_REGISTRY[category_name][prompt_name] + except Exception: + raise ValueError( + f"expected only a single `:` as separator between \ + prompt category and name, but got `{prompt_id}` instead" + ) + + +def load_prompt_list( + use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs +): + category_name, prompt_name = use_prompt.split(":") + + if category_name == "promptsource": + from promptsource.templates import DatasetTemplates + + if subset_name is None: + prompts = DatasetTemplates(dataset_name=dataset_name) + else: + prompts = DatasetTemplates( + dataset_name=dataset_name, subset_name=subset_name + ) + + prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) + + elif ".yaml" in category_name: + import yaml + + if yaml_path is not None: + category_name = os.path.realpath(os.path.join(yaml_path, category_name)) + + with open(category_name, "rb") as file: + prompt_yaml_file = yaml.full_load(file) + + prompt_list = utils.pattern_match( + prompt_name, prompt_yaml_file["prompts"].keys() + ) + + # category_name, *prompt_name = use_prompt.split(":") + # TODO allow to multiple prompt naming + # if len(prompt_name) > 1: + # prompt_list = [] + # for prompt in prompt_name: + # prompt_list.append(utils.pattern_match(prompt_name, prompts.all_template_names)) + # else: + # prompt_list = utils.pattern_match(prompt_name, prompts.all_template_names) + return [":".join([category_name, prompt]) for prompt in prompt_list] + + +class PromptString: + def __init__(self, prompt_string): + self.prompt_string = prompt_string + + def apply(self, doc): + doc_to_text = self.prompt_string["doc_to_text"] + doc_to_target = self.prompt_string["doc_to_target"] + + # TODO need a way to process doc_to_choice + if "doc_to_choice" in self.prompt_string: + raise NotImplementedError("Not yet implemented to accept doc_to_choice") + + text_string = utils.apply_template(doc_to_text, doc) + target_string = utils.apply_template(doc_to_target, doc) + + return [text_string, target_string] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/__init__.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73d896452e06c2cc2909c290de70dcf87b0c6f90 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/__init__.py @@ -0,0 +1,670 @@ +import collections +import inspect +import logging +import os +from functools import partial +from typing import Dict, List, Mapping, Optional, Union + +from dllm_eval import utils +from dllm_eval.api.group import ConfigurableGroup, GroupConfig +from dllm_eval.api.task import ConfigurableTask, Task +from dllm_eval.evaluator_utils import get_subtask_list + + +GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys()) + +eval_logger = logging.getLogger(__name__) + + +class TaskManager: + """TaskManager indexes all tasks from the default `dllm_eval/tasks/` + and an optional directory if provided. + + """ + + def __init__( + self, + verbosity: Optional[str] = None, + include_path: Optional[Union[str, List]] = None, + include_defaults: bool = True, + metadata: Optional[dict] = None, + ) -> None: + if verbosity is not None: + utils.setup_logging(verbosity) + self.include_path = include_path + self.metadata = metadata + self._task_index = self.initialize_tasks( + include_path=include_path, include_defaults=include_defaults + ) + self._all_tasks = sorted(list(self._task_index.keys())) + + self._all_groups = sorted( + [x for x in self._all_tasks if self._task_index[x]["type"] == "group"] + ) + self._all_subtasks = sorted( + [ + x + for x in self._all_tasks + if self._task_index[x]["type"] in ["task", "python_task"] + ] + ) + self._all_tags = sorted( + [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"] + ) + + self.task_group_map = collections.defaultdict(list) + + def initialize_tasks( + self, + include_path: Optional[Union[str, List]] = None, + include_defaults: bool = True, + ) -> dict[str, dict]: + """Creates a dictionary of tasks indexes. + + :param include_path: Union[str, List] = None + An additional path to be searched for tasks recursively. + Can provide more than one such path as a list. + :param include_defaults: bool = True + If set to false, default tasks (those in dllm_eval/tasks/) are not indexed. + return + Dictionary of task names as key and task metadata + """ + if include_defaults: + all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"] + else: + all_paths = [] + if include_path is not None: + if isinstance(include_path, str): + include_path = [include_path] + all_paths.extend(include_path) + + task_index = {} + for task_dir in all_paths: + tasks = self._get_task_and_group(task_dir) + task_index = {**tasks, **task_index} + + return task_index + + @property + def all_tasks(self): + return self._all_tasks + + @property + def all_groups(self): + return self._all_groups + + @property + def all_subtasks(self): + return self._all_subtasks + + @property + def all_tags(self): + return self._all_tags + + @property + def task_index(self): + return self._task_index + + def list_all_tasks( + self, list_groups=True, list_tags=True, list_subtasks=True + ) -> str: + from pytablewriter import MarkdownTableWriter + + def sanitize_path(path): + # don't print full path if we are within the dllm_eval/tasks dir ! + # if we aren't though, provide the full path. + if "dllm_eval/tasks/" in path: + return "dllm_eval/tasks/" + path.split("dllm_eval/tasks/")[-1] + else: + return path + + group_table = MarkdownTableWriter() + group_table.headers = ["Group", "Config Location"] + gt_values = [] + for g in self.all_groups: + path = self.task_index[g]["yaml_path"] + if path == -1: + path = "---" + else: + path = sanitize_path(path) + gt_values.append([g, path]) + group_table.value_matrix = gt_values + + tag_table = MarkdownTableWriter() + tag_table.headers = ["Tag"] + tag_table.value_matrix = [[t] for t in self.all_tags] + + subtask_table = MarkdownTableWriter() + subtask_table.headers = ["Task", "Config Location", "Output Type"] + st_values = [] + for t in self.all_subtasks: + path = self.task_index[t]["yaml_path"] + + output_type = "" + + # read the yaml file to determine the output type + if path != -1: + config = utils.load_yaml_config(path, mode="simple") + if "output_type" in config: + output_type = config["output_type"] + elif ( + "include" in config + ): # if no output type, check if there is an include with an output type + include_path = path.split("/")[:-1] + config["include"] + include_config = utils.load_yaml_config(include_path, mode="simple") + if "output_type" in include_config: + output_type = include_config["output_type"] + + if path == -1: + path = "---" + else: + path = sanitize_path(path) + st_values.append([t, path, output_type]) + subtask_table.value_matrix = st_values + + result = "\n" + if list_groups: + result += group_table.dumps() + "\n\n" + if list_tags: + result += tag_table.dumps() + "\n\n" + if list_subtasks: + result += subtask_table.dumps() + "\n\n" + return result + + def match_tasks(self, task_list: list[str]) -> list[str]: + return utils.pattern_match(task_list, self.all_tasks) + + def _name_is_registered(self, name: str) -> bool: + if name in self.all_tasks: + return True + return False + + def _name_is_task(self, name: str) -> bool: + if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): + return True + return False + + def _name_is_tag(self, name: str) -> bool: + if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"): + return True + return False + + def _name_is_group(self, name: str) -> bool: + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "group" + ): + return True + return False + + def _name_is_python_task(self, name: str) -> bool: + if self._name_is_registered(name) and ( + self.task_index[name]["type"] == "python_task" + ): + return True + return False + + def _config_is_task(self, config: dict) -> bool: + if ("task" in config) and isinstance(config["task"], str): + return True + return False + + def _config_is_group(self, config: dict) -> bool: + if ("task" in config) and isinstance(config["task"], list): + return True + return False + + def _config_is_python_task(self, config: dict) -> bool: + if "class" in config: + return True + return False + + def _get_yaml_path(self, name: str): + if name not in self.task_index: + raise ValueError + return self.task_index[name]["yaml_path"] + + def _get_config(self, name): + if name not in self.task_index: + raise ValueError + yaml_path = self._get_yaml_path(name) + if yaml_path == -1: + return {} + else: + return utils.load_yaml_config(yaml_path, mode="full") + + def _get_tasklist(self, name): + if self._name_is_task(name): + raise ValueError + return self.task_index[name]["task"] + + def _process_alias(self, config, group=None): + # If the group is not the same as the original + # group which the group alias was intended for, + # Set the group_alias to None instead. + if ("group_alias" in config) and ("group" in config) and group is not None: + if config["group"] != group: + config["group_alias"] = None + return config + + def _class_has_config_in_constructor(self, cls): + constructor = getattr(cls, "__init__", None) + return ( + "config" in inspect.signature(constructor).parameters + if constructor + else False + ) + + def _load_individual_task_or_group( + self, + name_or_config: Optional[Union[str, dict]] = None, + parent_name: Optional[str] = None, + update_config: Optional[dict] = None, + ) -> Mapping: + def _load_task(config, task): + if "include" in config: + config = { + **utils.load_yaml_config( + yaml_path=None, + yaml_config={"include": config.pop("include")}, + mode="full", + ), + **config, + } + if self._config_is_python_task(config): + if self._class_has_config_in_constructor(config["class"]): + task_object = config["class"](config=config) + else: + task_object = config["class"]() + if isinstance(task_object, ConfigurableTask): + # very scuffed: set task name here. TODO: fixme? + task_object.config.task = task + else: + if self.metadata is not None: + config["metadata"] = config.get("metadata", {}) | self.metadata + else: + config["metadata"] = config.get("metadata", {}) + task_object = ConfigurableTask(config=config) + + return {task: task_object} + + def _get_group_and_subtask_from_config( + config: dict, + ) -> tuple[ConfigurableGroup, list[str]]: + if self.metadata is not None: + config["metadata"] = config.get("metadata", {}) | self.metadata + group_name = ConfigurableGroup(config=config) + subtask_list = [] + for task in group_name.config["task"]: + if isinstance(task, str) and self._name_is_tag(task): + subtask_list.extend(self._get_tasklist(task)) + else: + subtask_list.append(task) + return group_name, subtask_list + + def _process_group_config( + config: dict, update_config: dict = None + ) -> tuple[dict, dict]: + if update_config is not None: + config = {**config, **update_config} + _update_config = { + k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS + } + if not bool(_update_config): + _update_config = None + + group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS} + return group_config, _update_config + + if isinstance(name_or_config, str): + if update_config is not None: + # Process name_or_config as a dict instead + name_or_config = {"task": name_or_config, **update_config} + elif self._name_is_task(name_or_config) or self._name_is_python_task( + name_or_config + ): + task_config = self._get_config(name_or_config) + return _load_task(task_config, task=name_or_config) + else: + subtask_list = self._get_tasklist(name_or_config) + if subtask_list == -1: + group_config = self._get_config(name_or_config) + group_config, update_config = _process_group_config(group_config) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + else: + if self._name_is_tag(name_or_config): + fn = partial( + self._load_individual_task_or_group, + update_config=name_or_config + if isinstance(name_or_config, dict) + else None, + ) + return dict( + collections.ChainMap(*map(fn, reversed(subtask_list))) + ) + else: + group_name = ConfigurableGroup( + config={"group": name_or_config, "task": subtask_list} + ) + + if isinstance(name_or_config, dict): + if self._config_is_task(name_or_config): + name = name_or_config.pop("task") + if update_config is not None: + name_or_config = {**name_or_config, **update_config} + # If the name is registered as a group + if self._name_is_group(name): + group_config = self._get_config(name) + + group_config, update_config = _process_group_config( + group_config, name_or_config + ) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + elif self._name_is_tag(name): + subtask_list = self._get_tasklist(name) + fn = partial( + self._load_individual_task_or_group, + update_config=name_or_config, + ) + return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) + else: + if self._name_is_registered(name): + base_task_config = self._get_config(name) + + # Check if this is a duplicate. + if parent_name is not None: + num_duplicate = len( + list( + filter( + lambda x: x.startswith(name), + self.task_group_map[parent_name], + ) + ) + ) + if num_duplicate > 0: + name = f"{name}-{num_duplicate}" + self.task_group_map[parent_name].append(name) + + task_config = { + **base_task_config, + **name_or_config, + } + else: + task_config = name_or_config + return _load_task(task_config, task=name) + else: + group_config, update_config = _process_group_config(name_or_config) + group_name, subtask_list = _get_group_and_subtask_from_config( + group_config + ) + + fn = partial( + self._load_individual_task_or_group, + parent_name=group_name, + update_config=update_config, + ) + return { + group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) + } + + def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: + """Loads a dictionary of task objects from a list + + :param task_list: Union[str, list] = None + Single string or list of string of task names to be loaded + + :return + Dictionary of task objects + """ + if isinstance(task_list, str): + task_list = [task_list] + + all_loaded_tasks = dict( + collections.ChainMap( + *map( + lambda task: self._load_individual_task_or_group(task), + task_list, + ) + ) + ) + return all_loaded_tasks + + def load_config(self, config: Dict): + return self._load_individual_task_or_group(config) + + def _get_task_and_group(self, task_dir: str): + """Creates a dictionary of tasks index with the following metadata, + - `type`, that can be either `task`, `python_task`, `group` or `tags`. + `task` refer to regular task configs, `python_task` are special + yaml files that only consists of `task` and `class` parameters. + `group` are group configs. `tags` are labels that can be assigned + to tasks to assist in sorting and calling tasks of certain themes. + - `yaml_path`, path to the yaml file. If the entry is a `group` that + was configured through a task config, the yaml_path will be -1 + and all subtasks will be listed in `task` (see below) + - `task`, reserved for entries with `type` as `group`. This will list + all subtasks. When a group config is created (as opposed to task + config having `group` parameter set), this will be set to -1 to + avoid recursive indexing. The whole list of subtasks will be loaded + at evaluation. + + :param task_dir: str + A directory to check for tasks + + :return + Dictionary of task names as key and task metadata + """ + + def _populate_tags_and_groups(config, task, tasks_and_groups, print_info): + # TODO: remove group in next release + if "tag" in config: + attr_list = config["tag"] + if isinstance(attr_list, str): + attr_list = [attr_list] + + for tag in attr_list: + if tag not in tasks_and_groups: + tasks_and_groups[tag] = { + "type": "tag", + "task": [task], + "yaml_path": -1, + } + elif tasks_and_groups[tag]["type"] != "tag": + eval_logger.info( + f"The tag '{tag}' is already registered as a group, this tag will not be registered. " + "This may affect tasks you want to call." + ) + break + else: + tasks_and_groups[tag]["task"].append(task) + + # TODO: remove group in next release + print_info = True + ignore_dirs = [ + "__pycache__", + ".ipynb_checkpoints", + ] + tasks_and_groups = collections.defaultdict() + for root, dirs, file_list in os.walk(task_dir): + dirs[:] = [d for d in dirs if d not in ignore_dirs] + for f in file_list: + if f.endswith(".yaml"): + yaml_path = os.path.join(root, f) + print(yaml_path) + config = utils.load_yaml_config(yaml_path, mode="simple") + if self._config_is_python_task(config): + # This is a python class config + task = config["task"] + tasks_and_groups[task] = { + "type": "python_task", + "yaml_path": yaml_path, + } + _populate_tags_and_groups( + config, task, tasks_and_groups, print_info + ) + elif self._config_is_group(config): + # This is a group config + tasks_and_groups[config["group"]] = { + "type": "group", + "task": -1, # This signals that + # we don't need to know + # the task list for indexing + # as it can be loaded + # when called. + "yaml_path": yaml_path, + } + + # # Registered the level 1 tasks from a group config + # for config in config["task"]: + # if isinstance(config, dict) and self._config_is_task(config): + # task = config["task"] + # tasks_and_groups[task] = { + # "type": "task", + # "yaml_path": yaml_path, + # } + + elif self._config_is_task(config): + # This is a task config + task = config["task"] + tasks_and_groups[task] = { + "type": "task", + "yaml_path": yaml_path, + } + _populate_tags_and_groups( + config, task, tasks_and_groups, print_info + ) + else: + eval_logger.debug(f"File {f} in {root} could not be loaded") + + return tasks_and_groups + + +def get_task_name_from_config(task_config: Dict[str, str]) -> str: + if "task" in task_config: + return task_config["task"] + if "dataset_name" in task_config: + return "{dataset_path}_{dataset_name}".format(**task_config) + else: + return "{dataset_path}".format(**task_config) + + +def get_task_name_from_object(task_object): + if hasattr(task_object, "config"): + return task_object._config["task"] + + # TODO: scrap this + # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting + return ( + task_object.EVAL_HARNESS_NAME + if hasattr(task_object, "EVAL_HARNESS_NAME") + else type(task_object).__name__ + ) + + +def _check_duplicates(task_dict: dict) -> None: + """helper function solely used in validating get_task_dict output. + Takes the output of dllm_eval.evaluator_utils.get_subtask_list and + returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are + "oversubscribed" to several disjoint groups. + """ + subtask_names = [] + for key, value in task_dict.items(): + subtask_names.extend(value) + + duplicate_tasks = { + task_name for task_name in subtask_names if subtask_names.count(task_name) > 1 + } + + # locate the potentially problematic groups that seem to 'compete' for constituent subtasks + competing_groups = [ + group + for group in task_dict.keys() + if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0 + ] + + if len(duplicate_tasks) > 0: + raise ValueError( + f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs." + ) + + +def get_task_dict( + task_name_list: Union[str, List[Union[str, Dict, Task]]], + task_manager: Optional[TaskManager] = None, +): + """Creates a dictionary of task objects from either a name of task, config, or prepared Task object. + + :param task_name_list: List[Union[str, Dict, Task]] + Name of model or LM object, see dllm_eval.models.get_model + :param task_manager: TaskManager = None + A TaskManager object that stores indexed tasks. If not set, + task_manager will load one. This should be set by the user + if there are additional paths that want to be included + via `include_path` + + :return + Dictionary of task objects + """ + + task_name_from_string_dict = {} + task_name_from_config_dict = {} + task_name_from_object_dict = {} + + if isinstance(task_name_list, str): + task_name_list = [task_name_list] + elif isinstance(task_name_list, list): + if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]): + raise TypeError( + "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match." + ) + else: + raise TypeError( + f"Expected a 'str' or 'list' but received {type(task_name_list)}." + ) + + string_task_name_list = [task for task in task_name_list if isinstance(task, str)] + others_task_name_list = [ + task for task in task_name_list if not isinstance(task, str) + ] + if len(string_task_name_list) > 0: + if task_manager is None: + task_manager = TaskManager() + + task_name_from_string_dict = task_manager.load_task_or_group( + string_task_name_list + ) + + for task_element in others_task_name_list: + if isinstance(task_element, dict): + task_name_from_config_dict = { + **task_name_from_config_dict, + **task_manager.load_config(config=task_element), + } + + elif isinstance(task_element, Task): + task_name_from_object_dict = { + **task_name_from_object_dict, + get_task_name_from_object(task_element): task_element, + } + + if not set(task_name_from_string_dict.keys()).isdisjoint( + set(task_name_from_object_dict.keys()) + ): + raise ValueError + + final_task_dict = { + **task_name_from_string_dict, + **task_name_from_config_dict, + **task_name_from_object_dict, + } + + # behavior can get odd if one tries to invoke several groups that "compete" for the same task. + # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask + # and we'd be unsure which to use and report.) + # we explicitly check and error in this case. + _check_duplicates(get_subtask_list(final_task_dict)) + + return final_task_dict diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/README.md b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25467f905f61ef28883579f54672eab0e7c7dec6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/README.md @@ -0,0 +1,55 @@ +# AIME + +### Citation + +```text +@dataset{aime_1983_2024, + author = {Hemish Veeraboina}, + title = {AIME Problem Set 1983-2024}, + year = {2024}, + publisher = {Kaggle}, + url = {https://www.kaggle.com/datasets/hemishveeraboina/aime-problem-set-1983-2024} +} + +@dataset{aime_2024, + author = {Maxwell Jia}, + title = {AIME Problem Set 2024}, + year = {2024}, + publisher = {Huggingface}, + url = {https://huggingface.co/datasets/Maxwell-Jia/AIME_2024} +} + +@dataset{aime_2025, + author = {math-ai}, + title = {AIME Problem Set 2025}, + year = {2025}, + publisher = {Huggingface}, + url = {https://huggingface.co/datasets/math-ai/aime25} +} +``` + +### Groups, Tags, and Tasks + +#### Groups + +* `math_word_problems` + +#### Tasks + +* `aime`: `AIME 1983-2024 problems` +* `aime24`: `AIME 2024 problems` +* `aime25`: `AIME 2025 problems` + +### Checklist + +For adding novel benchmarks/datasets to the library: + +* [x] Is the task an existing benchmark in the literature? + * [x] Have you referenced the original paper that introduced the task? + * [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test? + +If other tasks on this dataset are already supported: + +* [ ] Is the "Main" variant of this task clearly denoted? +* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? +* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9a6cced3adcc8f8918e55c49fbc92eeda2b7623 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime.yaml @@ -0,0 +1,29 @@ +tag: + - math_word_problems +task: aime +dataset_path: gneubig/aime-1983-2024 +# dataset_name: null +output_type: generate_until +training_split: train +fewshot_split: train +test_split: train +doc_to_text: "Question: {{Question}}\nAnswer:" +doc_to_target: "{{Answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + - "<|eot_id|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 32768 +repeats: 1 +num_fewshot: 0 +metadata: + version: 0.0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime24.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime24.yaml new file mode 100644 index 0000000000000000000000000000000000000000..714596912615b5c16d4708e21f0eb56b33959754 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime24.yaml @@ -0,0 +1,29 @@ +tag: + - math_word_problems +task: aime24 +dataset_path: Maxwell-Jia/AIME_2024 +# dataset_name: null +output_type: generate_until +training_split: train +fewshot_split: train +test_split: train +doc_to_text: "Question: {{Problem}}\nAnswer:" +doc_to_target: "{{Answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + - "<|eot_id|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 32768 +repeats: 1 +num_fewshot: 0 +metadata: + version: 0.0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime25.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime25.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ef64005863674f7afc5c76b8cdff22d224ae2da --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/aime25.yaml @@ -0,0 +1,29 @@ +tag: + - math_word_problems +task: aime25 +dataset_path: math-ai/aime25 +# dataset_name: null +output_type: generate_until +training_split: test +fewshot_split: test +test_split: test +doc_to_text: "Question: {{problem}}\nAnswer:" +doc_to_target: "{{answer}}" +process_results: !function utils.process_results +metric_list: + - metric: exact_match + aggregation: mean + higher_is_better: true +generation_kwargs: + until: + - "Question:" + - "" + - "<|im_end|>" + - "<|eot_id|>" + do_sample: false + temperature: 0.0 + max_gen_toks: 32768 +repeats: 1 +num_fewshot: 0 +metadata: + version: 0.0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f668c23bc18d646c16390302ad24cc3ced1aa3b4 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/aime/utils.py @@ -0,0 +1,231 @@ +import re +from typing import Dict, List + + +def process_results(doc: dict, results: List[str]) -> Dict[str, int]: + retval = 0 + response = results[0] + + # Try to extract answer from $...$ format first + indices = [pos for pos, char in enumerate(response) if char == "$"] + if len(indices) <= 1: + answer = response + else: + answer = response[indices[0] + 1 : indices[-1]] + + # Extract from \\boxed{} if present + boxed_answer = last_boxed_only_string(response) + if boxed_answer is not None: + try: + boxed_content = remove_boxed(boxed_answer) + if boxed_content is not None: + answer = boxed_content + except (AssertionError, IndexError): + pass + + # Check if answer matches target + answer_key = next(k for k in doc.keys() if k.lower() == "answer") + target = str(doc[answer_key]) + if is_equiv(answer, target): + retval = 1 + + return {"exact_match": retval} + + +# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py +def is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + print("WARNING: Both None") + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except Exception: + return str1 == str2 + + +def remove_boxed(s): + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +def last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except AssertionError: + return string + + +def remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = fix_a_slash_b(string) + + return string diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/gsm8k/gsm8k.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/gsm8k/gsm8k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c56206923cf19bac4ec07233c6b0b17ac0460ad --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/gsm8k/gsm8k.yaml @@ -0,0 +1,15 @@ +task: gsm8k +dataset_path: openai/gsm8k +dataset_name: main +output_type: generate_until +training_split: train +fewshot_split: train +test_split: test +doc_to_text: !function utils.gsm_prompt +doc_to_target: "{{answer.split('####')[-1].strip()}}" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/gsm8k/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/gsm8k/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ceaa3d2ab7af89f27e69b470a2f6787f6133519 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/gsm8k/utils.py @@ -0,0 +1,13 @@ +def gsm_prompt(doc): + system_prompt = ( + "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n" + "Respond in the following format:\n" + "\n" + "Your reasoning here\n" + "\n" + "\n" + "\\boxed{...}\n" + "" + ) + prompt = f"{system_prompt}\n\n{doc['question']}\n\n" + return prompt diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/humaneval/humaneval.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/humaneval/humaneval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..024d38f0da160e853cd8c3123104a4485677c0fd --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/humaneval/humaneval.yaml @@ -0,0 +1,13 @@ +task: humaneval +dataset_path: openai/openai_humaneval +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "Write a solution to the following problem and make sure that it passes the tests:\n{{prompt}}\n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n\nYour reasoning here\n\n\n```python\nThe complete implementation of the {{entry_point}} function\n```\n" +doc_to_target: "{{test}}\ncheck({{entry_point}})" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/humaneval/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/humaneval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11bac61cfa12fad57aacfed28b55bee467cf23e4 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/humaneval/utils.py @@ -0,0 +1,43 @@ +import evaluate as hf_evaluate + + +try: + compute_ = hf_evaluate.load("code_eval") + test_cases = ["assert add(2, 3)==5"] + candidates = [["def add(a,b): return a*b"]] + results = compute_.compute(references=test_cases, predictions=candidates, k=[1]) +except Exception as e: + raise e + + +def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None): + global compute_ + assert k is not None + if isinstance(k, int): + k = [k] + res = compute_.compute( + references=references, + predictions=predictions, + k=k + ) + return res[0] + + +def clean_response_string(r: str) -> str: + cleaned_text = r if r.rfind("```python") == -1 else r[r.rfind("```python"):] + cleaned_text = cleaned_text if cleaned_text.rfind("```") == -1 else cleaned_text[: cleaned_text.rfind("```")] + cleaned_text = cleaned_text if cleaned_text.rfind("if __name__ == \"__main__\":") == -1 else cleaned_text[: cleaned_text.rfind("if __name__ == \"__main__\":")] + return cleaned_text + + +def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + return [[doc["prompt"] + r for r in resp] for resp, doc in zip(resps, docs)] + + +def build_predictions( + resps: list[list[str]], docs: list[dict] +) -> list[list[str]]: + return [ + [clean_response_string(r) for r in resp] + for resp, doc in zip(resps, docs) + ] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/math500/math500.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/math500/math500.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fe2f7a38417fe863c1301953be514b618054707 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/math500/math500.yaml @@ -0,0 +1,12 @@ +task: math500 +dataset_path: HuggingFaceH4/MATH-500 +output_type: generate_until +test_split: test +doc_to_text: !function utils.math500_prompt +doc_to_target: "{{answer}}" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +repeats: 1 +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/math500/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/math500/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0585298c29c8b5c12ebeaa01dfff572267db601 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/math500/utils.py @@ -0,0 +1,14 @@ +def math500_prompt(doc): + system_prompt = ( + "You are a math expert. You will be given a question to solve. Solve it step by step. Wrap the final answer in a \\boxed{}. \n" + "Respond in the following format:\n" + "\n" + "Your reasoning here\n" + "\n" + "\n" + "\\boxed{...}\n" + "" + ) + + prompt = f"{system_prompt}\n\n{doc['problem']}\n\n" + return prompt diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/mbpp/mbpp.yaml b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/mbpp/mbpp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5b9755ad30669e2335bd374ba5f53db0572630f --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/mbpp/mbpp.yaml @@ -0,0 +1,14 @@ +task: mbpp +dataset_path: google-research-datasets/mbpp +dataset_name: full +unsafe_code: true +output_type: generate_until +test_split: test +doc_to_text: "\n{{text}} Your code should pass these tests:\n\n{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}} \n\nFirst, reason about the solution step-by-step. Then, write the code.\nRespond in the following format:\n\nYour reasoning here\n\n\n```python\nThe complete implementation of the function\n```\n" +doc_to_target: "{% if is_fewshot is defined %}{{code}}\n[DONE]{% else %}{{test_list[0]}}\n{{test_list[1]}}\n{{test_list[2]}}{% endif %}" +target_delimiter: "" +generation_kwargs: + until: + - "[NO_UNTIL_PLACEHOLDER]" + do_sample: false +num_fewshot: 0 diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/mbpp/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/mbpp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..234fc7ed5de047e556dea2ff77d02a232c8f3e6e --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/tasks/mbpp/utils.py @@ -0,0 +1,79 @@ +import re +from typing import Union + +import evaluate as hf_evaluate + + +try: + pass_at_k = hf_evaluate.load("code_eval") + + # run simple test to check code execution is enabled before model generation + test_cases = ["assert add(2, 3)==5"] + candidates = [["def add(a,b): return a*b"]] + results = pass_at_k.compute(references=test_cases, predictions=candidates, k=[1]) +except Exception as e: + raise e + + +def pass_at_1( + references: Union[str, list[str]], predictions: Union[str, list[list[str]]] +) -> float: + if isinstance(references, str): + references = [references] + if isinstance(predictions[0], str): + predictions = [[p] for p in predictions] + return pass_at_k.compute( + references=references, + predictions=predictions, + k=[1], + num_workers=48 + )[0]["pass@1"] + + +def extract_code_blocks(text: str) -> str: + text = re.sub(r"\[DONE\]", "", text) + text = re.sub(r"<\|eot_id\|>", "", text) + text = re.sub(r"<\|endoftext\|>", "", text) + return text + + +def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]: + return [[extract_code_blocks(r) for r in resp] for resp in resps] + + +def list_fewshot_samples(): + return [ + { + "task_id": 2, + "text": "Write a function to find the similar elements from the given two tuple lists.", + "code": "def similar_elements(test_tup1, test_tup2):\r\n res = tuple(set(test_tup1) & set(test_tup2))\r\n return (res) ", + "test_list": [ + "assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)", + "assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)", + "assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)", + ], + "is_fewshot": True, + }, + { + "task_id": 3, + "text": "Write a python function to identify non-prime numbers.", + "code": "import math\r\ndef is_not_prime(n):\r\n result = False\r\n for i in range(2,int(math.sqrt(n)) + 1):\r\n if n % i == 0:\r\n result = True\r\n return result", + "test_list": [ + "assert is_not_prime(2) == False", + "assert is_not_prime(10) == True", + "assert is_not_prime(35) == True", + ], + "is_fewshot": True, + }, + { + "task_id": 4, + "text": "Write a function to find the largest integers from a given list of numbers using heap queue algorithm.", + "code": "import heapq as hq\r\ndef heap_queue_largest(nums,n):\r\n largest_nums = hq.nlargest(n, nums)\r\n return largest_nums", + "test_list": [ + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],2)==[85, 75] ", + "assert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],5)==[85, 75, 65, 58, 35]", + ], + "is_fewshot": True, + }, + ] diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/utils.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d75d370a30862ba13dc3905f39031f631d70e8fd --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/dllm_eval/utils.py @@ -0,0 +1,552 @@ +import collections +import fnmatch +import functools +import hashlib +import importlib.util +import inspect +import json +import logging +import os +import re +from dataclasses import asdict, is_dataclass +from itertools import islice +from pathlib import Path +from typing import Any, Callable, Generator, List, Optional, Tuple + +import numpy as np +import yaml +from jinja2 import BaseLoader, Environment, StrictUndefined + + +SPACING = " " * 47 + +HIGHER_IS_BETTER_SYMBOLS = { + True: "↑", + False: "↓", +} + + +def setup_logging(verbosity=logging.INFO): + # Configure the root logger + class CustomFormatter(logging.Formatter): + def format(self, record): + if record.name.startswith("dllm_eval."): + record.name = record.name[len("dllm_eval.") :] + return super().format(record) + + formatter = CustomFormatter( + "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d:%H:%M:%S", + ) + + log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity + + level_map = { + "DEBUG": logging.DEBUG, + "INFO": logging.INFO, + "WARNING": logging.WARNING, + "ERROR": logging.ERROR, + "CRITICAL": logging.CRITICAL, + } + + log_level = level_map.get(str(log_level).upper(), logging.INFO) + + if not logging.root.handlers: + handler = logging.StreamHandler() + handler.setFormatter(formatter) + + root_logger = logging.getLogger() + root_logger.addHandler(handler) + root_logger.setLevel(log_level) + + if log_level == logging.DEBUG: + third_party_loggers = ["urllib3", "filelock", "fsspec"] + for logger_name in third_party_loggers: + logging.getLogger(logger_name).setLevel(logging.INFO) + else: + logging.getLogger().setLevel(log_level) + + +def hash_string(string: str) -> str: + return hashlib.sha256(string.encode("utf-8")).hexdigest() + + +def escaped_split(text, sep_char, maxsplit=-1): + """Split text into a list on occurrences of the given separation + character `sep_char`. The separation character may be escaped by a + backslash to avoid splitting at that location. + + The separation character must be a string of size 1. + + If `maxsplit` is given, at most `maxsplit` splits are done (thus, + the list will have at most `maxsplit + 1` elements). If `maxsplit` + is not specified or less than 0, then there is no limit on the + number of splits (all possible splits are made). + """ + assert len(sep_char) == 1, ( + "separation string must be a single character for escaped splitting" + ) + + if maxsplit == 0: + return text + maxsplit = max(0, maxsplit) + + return re.split(r"(? dict: + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + if args_string is None: + return {} + args_string = args_string.strip() + if not args_string: + return {} + arg_list = [arg for arg in args_string.split(",") if arg] + args_dict = { + kv[0]: handle_arg_string("=".join(kv[1:])) + for kv in [arg.split("=") for arg in arg_list] + } + return args_dict + + +def join_iters(iters): + for iter in iters: + yield from iter + + +def group(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + + return list(res.values()) + + +# Returns a list containing all values of the source_list that +# match at least one of the patterns +def pattern_match(patterns, source_list): + if isinstance(patterns, str): + patterns = [patterns] + + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return sorted(list(task_names)) + + +def softmax(x) -> np.ndarray: + """Compute softmax values for each sets of scores in x.""" + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + +def general_detokenize(string) -> str: + string = string.replace(" n't", "n't") + string = string.replace(" )", ")") + string = string.replace("( ", "(") + string = string.replace('" ', '"') + string = string.replace(' "', '"') + string = re.sub(r" (['.,])", r"\1", string) + return string + + +def get_file_task_name(filename: str) -> str: + """ + Given the sample results filenames, extracts and returns the task name. + """ + return filename[filename.find("_") + 1 : filename.rfind("_")] + + +def get_file_datetime(filename: str) -> str: + """ + Given the results and sample results filenames, extracts and returns the datetime. + """ + return filename[filename.rfind("_") + 1 :].replace(".jsonl", "") + + +def sanitize_model_name(model_name: str) -> str: + """ + Given the model name, returns a sanitized version of it. + """ + return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name) + + +def sanitize_task_name(task_name: str) -> str: + """ + Given the task name, returns a sanitized version of it. + """ + return re.sub(r"\W", "_", task_name) + + +def get_latest_filename(filenames: List[str]) -> str: + """ + Given a list of filenames, returns the filename with the latest datetime. + """ + return max(filenames, key=lambda f: get_file_datetime(f)) + + +def get_results_filenames(filenames: List[str]) -> List[str]: + """ + Extracts filenames that correspond to aggregated results. + """ + return [f for f in filenames if "/results_" in f and ".json" in f] + + +def get_sample_results_filenames(filenames: List[str]) -> List[str]: + """ + Extracts filenames that correspond to sample results. + """ + return [f for f in filenames if "/samples_" in f and ".json" in f] + + +def get_rolling_token_windows( + token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int +) -> Generator[Tuple[List[int], List[int]], None, None]: + """ + - context_len allows for a rolling window context, allowing each prediction window to potentially + condition on some context + + :param token_list: list + List of tokens to be PREDICTED + :param max_seq_len: int + max_seq_len of model (or max_seq_len we want to use) + :param context_len: int + Amount of desired token context for prediction. Needs to be at least 1. + :param prefix_token: token + Dummy token like so the first token has something to condition on + :return: generator + Generator of tuples + (input_tokens, pred_tokens) + Note: Score only the last len(pred_tokens) logits of the LM + """ + assert 1 <= context_len <= max_seq_len + if not token_list: + return + # +1 offset, going from input->preds + pred_len = max_seq_len - context_len + 1 + predicted = 0 + + # Special handling for first window: predict all tokens + first_seq_len = min(max_seq_len, len(token_list)) + yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len] + predicted += first_seq_len + + while predicted < len(token_list): + window_pred_len = min(len(token_list) - predicted, pred_len) + window_end = predicted + window_pred_len + + yield ( + token_list[window_end - max_seq_len - 1 : window_end - 1], + token_list[window_end - window_pred_len : window_end], + ) + predicted += window_pred_len + + +def make_disjoint_window( + pair: Tuple[List[int], List[int]], +) -> Tuple[List[int], List[int]]: + """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation""" + a, b = pair + return a[: len(a) - (len(b) - 1)], b + + +class EnhancedJSONEncoder(json.JSONEncoder): + """ + Provides a proper json encoding for the loggers and trackers json dumps. + Notably manages the json encoding of dataclasses. + """ + + def default(self, o): + if is_dataclass(o): + return asdict(o) + return super().default(o) + + +class Reorderer: + def __init__(self, arr: List[Any], fn: Callable) -> None: + """Reorder an array according to some function + + Args: + arr (List[Any]): The initial array + fn (Callable[[Any], Any]): A function to determine the priority of elements + """ + self.size = len(arr) + arr = list(enumerate(arr)) + arr = group(arr, lambda x: fn(x[1])) + # arr = [([y[0] for y in x], x[0][1]) for x in arr] + # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this + arr = [([y[0]], x[0][1]) for x in arr for y in x] + arr.sort(key=lambda x: fn(x[1])) + + self.arr = arr + + def get_reordered(self): + """Gets the reordered array + + Returns: + List[Any]: The reordered array + """ + return [x[1] for x in self.arr] + + def get_original(self, newarr): + """Restores the original order of a new array based on the old array's order + + Args: + newarr (List[Any]): The array to be restored + + Returns: + List[Any]: The array restored to the original order + """ + res = [None] * self.size + cov = [False] * self.size + + for (inds, _), v in zip(self.arr, newarr): + for ind in inds: + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + + +def make_table(result_dict, column: str = "results", sort_results: bool = False): + """Generate table of results.""" + from pytablewriter import LatexTableWriter, MarkdownTableWriter + + if column == "results": + column_name = "Tasks" + elif column == "groups": + column_name = "Groups" + + all_headers = [ + column_name, + "Version", + "Filter", + "n-shot", + "Metric", + "", + "Value", + "", + "Stderr", + ] + + md_writer = MarkdownTableWriter() + latex_writer = LatexTableWriter() + md_writer.headers = all_headers + latex_writer.headers = all_headers + + values = [] + + keys = result_dict[column].keys() + if sort_results: + # sort entries alphabetically by task or group name. + # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu. + # sorting here would mess that up + keys = sorted(keys) + for k in keys: + dic = result_dict[column][k] + version = result_dict["versions"].get(k, " N/A") + n = str(result_dict.get("n-shot", " ").get(k, " ")) + higher_is_better = result_dict.get("higher_is_better", {}).get(k, {}) + + if "alias" in dic: + k = dic.pop("alias") + + metric_items = dic.items() + metric_items = sorted(metric_items) + + for (mf), v in metric_items: + m, _, f = mf.partition(",") + if m.endswith("_stderr"): + continue + + hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "") + + v = "%.4f" % v if isinstance(v, float) else v + + if m + "_stderr" + "," + f in dic: + se = dic[m + "_stderr" + "," + f] + se = " N/A" if se == "N/A" else "%.4f" % se + values.append([k, version, f, n, m, hib, v, "±", se]) + else: + values.append([k, version, f, n, m, hib, v, "", ""]) + k = "" + version = "" + md_writer.value_matrix = values + latex_writer.value_matrix = values + + # todo: make latex table look good + # print(latex_writer.dumps()) + + return md_writer.dumps() + + +def positional_deprecated(fn): + """ + A decorator to nudge users into passing only keyword args (`kwargs`) to the + wrapped function, `fn`. + """ + + @functools.wraps(fn) + def _wrapper(*args, **kwargs): + if len(args) != 1 if inspect.ismethod(fn) else 0: + print( + f"WARNING: using {fn.__name__} with positional arguments is " + "deprecated and will be disallowed in a future version of " + "lm-evaluation-harness!" + ) + return fn(*args, **kwargs) + + return _wrapper + + +def ignore_constructor(loader, node): + return node + + +def import_function(loader: yaml.Loader, node, yaml_path: Path): + function_name = loader.construct_scalar(node) + + *module_name, function_name = function_name.split(".") + if isinstance(module_name, list): + module_name = ".".join(module_name) + module_path = yaml_path.parent / f"{module_name}.py" + + spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix()) + + if spec is None: + raise ImportError(f"Could not import module {module_name} from {module_path}.") + module = importlib.util.module_from_spec(spec) + + if spec.loader is None: + raise ImportError(f"Module loader is None, {module_name} from {module_path}.") + spec.loader.exec_module(module) + + function = getattr(module, function_name) + return function + + +def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"): + if mode == "simple": + constructor_fn = ignore_constructor + elif mode == "full": + if yaml_path is None: + raise ValueError("yaml_path must be provided if mode is 'full'.") + # Attach yaml_path to the import function so that it can be used later + constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path)) + + loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader + # Add the import_function constructor to the YAML loader + yaml.add_constructor("!function", constructor_fn, Loader=loader) + if yaml_config is None: + with open(yaml_path, "rb") as file: + yaml_config = yaml.load(file, Loader=loader) + + if yaml_dir is None: + yaml_dir = os.path.dirname(yaml_path) + + assert yaml_dir is not None + + if "include" in yaml_config: + include_path = yaml_config["include"] + del yaml_config["include"] + + if isinstance(include_path, str): + include_path = [include_path] + + # Load from the last one first + include_path.reverse() + final_yaml_config = {} + for path in include_path: + # Assumes that path is a full path. + # If not found, assume the included yaml + # is in the same dir as the original yaml + if not os.path.isfile(path): + path = os.path.join(yaml_dir, path) + + try: + included_yaml_config = load_yaml_config(yaml_path=path, mode=mode) + final_yaml_config.update(included_yaml_config) + except Exception as ex: + # If failed to load, ignore + raise ex + + final_yaml_config.update(yaml_config) + return final_yaml_config + return yaml_config + + +def regex_replace(string, pattern, repl, count: int = 0): + """Implements the `re.sub` function as a custom Jinja filter.""" + return re.sub(pattern, repl, string, count=count) + + +env = Environment( + loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True +) +env.filters["regex_replace"] = regex_replace + + +def apply_template(template: str, doc: dict) -> str: + rtemplate = env.from_string(template) + return rtemplate.render(**doc) + + +def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None): + """ + Method for creating a (potentially) sliced and limited + iterator from a raw document iterator. Used for splitting data + among ranks in multigpu setting or only pulling a sample of documents + """ + return islice(raw_iterator, rank, limit, world_size) + + +def weighted_f1_score(items): + from sklearn.metrics import f1_score + + unzipped_list = list(zip(*items)) + golds = unzipped_list[0] + preds = unzipped_list[1] + fscore = f1_score(golds, preds, average="weighted") + return fscore diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/evaluation_script.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/evaluation_script.py new file mode 100644 index 0000000000000000000000000000000000000000..0c90bd0c9c7ebd1f15b77158670a3858f5532468 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/evaluation_script.py @@ -0,0 +1,21 @@ +import os +import torch +import random +import numpy as np +from dllm_eval.__main__ import cli_evaluate + + +def set_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +if __name__ == "__main__": + os.environ["HF_ALLOW_CODE_EVAL"] = "1" + os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" + set_seed(42) + cli_evaluate() \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/gsm8k_all.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/gsm8k_all.py new file mode 100644 index 0000000000000000000000000000000000000000..7133a935166c211bf8a8f2e535ae7e1bd54061b6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/gsm8k_all.py @@ -0,0 +1,286 @@ +import json +import re +import os +import math +import argparse +from collections import Counter + +RES_PATH = "" + +def last_boxed_only_string(string): + if not string: return None + idx = max(string.rfind("\\boxed"), string.rfind("\\fbox")) + if idx < 0: return None + + if "\\boxed " in string[idx:idx+8] and "{" not in string[idx:idx+8]: + return "\\boxed " + string[idx:].split("\\boxed ")[-1].split("$")[0].strip() + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + elif string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + return string[idx : right_brace_idx + 1] if right_brace_idx else None + +def remove_boxed(s): + if not s: return None + if "\\boxed " in s: return s[len("\\boxed ") :] + if "\\boxed{" in s and s.endswith("}"): return s[len("\\boxed{") : -1] + if "\\fbox{" in s and s.endswith("}"): return s[len("\\fbox{") : -1] + return s + +def strip_string(string): + if string is None: return "" + string = str(string).strip() + while re.search(r"(\d),(\d{3})", string): + string = re.sub(r"(\d),(\d{3})", r"\1\2", string) + + string = string.replace("\n", "").replace("\\!", "") + string = string.replace("tfrac", "frac").replace("dfrac", "frac") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("^{\\circ}", "").replace("^\\circ", "") + string = string.replace("\\$", "").replace("\\%", "").replace("\%", "") + + if "=" in string and len(string.split("=")[0]) <= 5: + string = string.split("=")[1].strip() + + string = string.replace(" ", "") + string = string.rstrip(".") + return string + +def normalize_to_number(s): + s_clean = strip_string(s) + try: + if '/' in s_clean and len(s_clean.split('/')) == 2: + parts = s_clean.split('/') + return float(parts[0]) / float(parts[1]) + return float(s_clean) + except: + return s_clean + +def extract_answer_gsm8k_debug(text): + if not text: return "", "empty" + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip() + + boxed = last_boxed_only_string(text) + if boxed: + ans = remove_boxed(boxed) + if ans: + return strip_string(ans), "boxed" + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + return strip_string(tag_match.group(1)), "xml_tag" + + last_text = text[-200:] if len(text) > 200 else text + marker = "the answer is" + if marker in last_text.lower(): + idx = last_text.lower().rfind(marker) + after = last_text[idx + len(marker):].strip() + after = re.split(r"[.\n]", after)[0] + after = after.replace(":", "").replace("$", "").strip() + return strip_string(after), "text_marker" + + tail = text[-50:] + nums = re.findall(r"(?>> 正在评测: {file_path}") + detailed_results = [] + + correct_voted_count = 0 + correct_any_count = 0 + total_count = 0 + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + try: + item = json.loads(line) + except: + continue + + doc = item.get("doc", {}) + ground_truth = extract_gold_gsm8k(str(item.get("target", ""))) + + total_nfe_item = item.get("nfe", 0) + nfe_list.append(total_nfe_item) + svf_list.append(item.get("svf_calls", 0)) + + trajectories = item.get("all_trajectories", []) + if not trajectories: + resps = item.get("resps", []) + for r in resps: + text = r[0] if isinstance(r, list) else r + trajectories.append({"resp": text, "score": 0.0}) + + parsed_paths = [] + traj_debug_info = [] + + for idx, traj in enumerate(trajectories): + raw_text = traj.get("resp", "") + score = traj.get("score", 0.0) + + extracted, method = extract_answer_gsm8k_debug(raw_text) + + is_correct_single = False + if extracted: + is_correct_single = is_equiv(extracted, ground_truth) + val_key = normalize_to_number(extracted) + + parsed_paths.append({ + "original_text": extracted, + "val_key": val_key, + "score": score, + "method": method + }) + + traj_debug_info.append({ + "id": idx, + "extracted": extracted, + "score": score, + "is_correct": is_correct_single, + "extract_method": method + }) + + if not parsed_paths: + detailed_results.append({ + "question": doc.get("question", "N/A"), + "final_voted_answer": "", + "ground_truth": ground_truth, + "is_voted_correct": False, + "trajectory_details": traj_debug_info, + "nfe": total_nfe_item, + "svf_calls": item.get("svf_calls", 0) + }) + total_count += 1 + continue + + has_correct = any(p['score'] > -999 and is_equiv(p['original_text'], ground_truth) for p in parsed_paths) + if has_correct: + correct_any_count += 1 + + parsed_paths.sort(key=lambda x: x['score'], reverse=True) + top_k_count = max(1, int(len(parsed_paths) * 0.6)) + voting_candidates = parsed_paths[:top_k_count] + + ans_stats = {} + for p in voting_candidates: + k = p['val_key'] + if k not in ans_stats: + ans_stats[k] = { + "total_weight": 0.0, + "count": 0, + "max_score": -float('inf'), + "best_repr": p['original_text'] + } + + try: + weight = math.exp(p['score']) + except OverflowError: + weight = float('inf') + + ans_stats[k]["total_weight"] += weight + ans_stats[k]["count"] += 1 + if p['score'] > ans_stats[k]["max_score"]: + ans_stats[k]["max_score"] = p['score'] + ans_stats[k]["best_repr"] = p['original_text'] + + sorted_answers = sorted( + ans_stats.items(), + key=lambda x: (x[1]["total_weight"], x[1]["max_score"]), + reverse=True + ) + + best_pred = str(sorted_answers[0][1]["best_repr"]) + is_voted_correct = is_equiv(best_pred, ground_truth) + if is_voted_correct: + correct_voted_count += 1 + + vote_summary = [] + for val, info in sorted_answers: + vote_summary.append({ + "answer": str(val), + "count": info["count"], + "total_weight": info["total_weight"], + "is_correct": is_equiv(str(val), ground_truth) + }) + + total_count += 1 + + detailed_results.append({ + "question": doc.get("question", "N/A"), + "final_voted_answer": best_pred, + "ground_truth": ground_truth, + "is_voted_correct": is_voted_correct, + "vote_stats": vote_summary, + "trajectory_details": traj_debug_info, + "nfe": total_nfe_item, + "svf_calls": item.get("svf_calls", 0) + }) + + accuracy = (correct_voted_count / total_count * 100) if total_count > 0 else 0 + pass_at_k = (correct_any_count / total_count * 100) if total_count > 0 else 0 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "summary": { + "accuracy": f"{accuracy:.2f}%", + "correct_voted": correct_voted_count, + "total": total_count, + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/humaneval_all.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/humaneval_all.py new file mode 100644 index 0000000000000000000000000000000000000000..842a77c8938d7de95247e6f153e42b00625dea99 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/humaneval_all.py @@ -0,0 +1,183 @@ +import os +import sys +import json +import ast +import traceback +import glob +import math +import argparse +from typing import Dict, List, Optional, Set, Tuple +from collections import Counter +import evaluate as hf_evaluate +import re + +RES_PATH = "" + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" + +def extract_python_code(text: str) -> str: + if not text: return "" + + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "") + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + text = tag_match.group(1) + + if "```python" in text: + content = text.split("```python")[-1] + if "```" in content: + return content.split("```")[0].strip() + return content.strip() + elif "```" in text: + content = text.split("```")[-1] + if "```" in content: + return content.split("```")[0].strip() + return content.strip() + + lines = text.split('\n') + cleaned_lines = [] + stop_words = ["Explanation:", "Example:", "Test Case:", "Output:"] + for line in lines: + if any(sw in line for sw in stop_words): + break + cleaned_lines.append(line) + + return "\n".join(cleaned_lines).strip() + +def normalize_code_for_voting(code: str) -> str: + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (node.body and isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)): + node.body.pop(0) + return ast.unparse(tree).strip() + except: + return re.sub(r"\s+", "", code) + +def sanitize(prompt: str, completion: str, entrypoint: str) -> str: + if f"def {entrypoint}" in completion: + return completion + return prompt + "\n" + completion + +def run_evaluation(target_path): + if os.path.isdir(target_path): + jsonl_files = glob.glob(os.path.join(target_path, "**/*.jsonl"), recursive=True) + else: + jsonl_files = [target_path] + + if not jsonl_files: + print(f"未在路径 {target_path} 下找到任何 .jsonl 文件") + return + + print(f"共找到 {len(jsonl_files)} 个评测任务") + code_eval = hf_evaluate.load("code_eval") + + for file_path in jsonl_files: + print(f"\n>>> 正在评测: {file_path}") + all_predictions = [] + all_references = [] + detailed_results = [] + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + if not lines: continue + + for line in lines: + if not line.strip(): continue + item = json.loads(line) + doc = item.get("doc", {}) + prompt = doc.get("prompt", "") + entry_point = doc.get("entry_point", "") + reference = doc.get("test", "") + + current_nfe = item.get("nfe", 0) + nfe_list.append(current_nfe) + svf_list.append(item.get("svf_calls", 0)) + + resps = item.get("resps", []) + candidate_stats = {} + + for r in resps: + raw_text = r[0] if isinstance(r, list) else r + completion = extract_python_code(raw_text) + full_code = sanitize(prompt, completion, entry_point) + + try: + ast.parse(full_code) + is_valid = True + except: + is_valid = False + + logic_norm = normalize_code_for_voting(full_code) + if not logic_norm: continue + + if logic_norm not in candidate_stats: + candidate_stats[logic_norm] = {"count": 0, "valid": is_valid, "code": full_code} + candidate_stats[logic_norm]["count"] += 1 + + if not candidate_stats: + voted_code = prompt + else: + sorted_logics = sorted( + candidate_stats.keys(), + key=lambda k: (candidate_stats[k]["valid"], candidate_stats[k]["count"]), + reverse=True + ) + voted_code = candidate_stats[sorted_logics[0]]["code"] + + all_predictions.append([voted_code]) + all_references.append(reference) + detailed_results.append({ + "task_id": doc.get("task_id", doc.get("name", "N/A")), + "voted_code": voted_code, + "nfe": current_nfe, + "svf_calls": item.get("svf_calls", 0), + "candidates_count": len(candidate_stats) + }) + + if not all_predictions: continue + + print(f"正在执行代码测试 (共 {len(all_predictions)} 题)...") + pass_at_k, exec_results = code_eval.compute( + references=all_references, + predictions=all_predictions, + k=[1], + num_workers=4 + ) + + accuracy = pass_at_k.get("pass@1", 0.0) * 100 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- 结果: Accuracy: {accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + for i, detail in enumerate(detailed_results): + res_list = exec_results.get(i, []) + detail["is_correct"] = res_list[0][1]["passed"] if res_list else False + + final_report = { + "summary": { + "accuracy": f"{accuracy:.2f}%", + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + print(f"报告已保存至: {output_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/math500_all.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/math500_all.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7d8671623f3e848ebca1c5836928185179f5fb --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/math500_all.py @@ -0,0 +1,213 @@ +import json +import re +import os +import math +import argparse +from collections import Counter + +RES_PATH = "" + +def extract_answer(text): + if not text: + return "", False + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").strip() + + boxed_pattern = r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}" + all_boxes = re.findall(boxed_pattern, text) + if all_boxes: + return all_boxes[-1], True + + tag_match = re.search(r"(.*?)", text, re.DOTALL) + if tag_match: + return tag_match.group(1).strip(), True + + marker = "the answer is" + if marker in text.lower(): + pos = text.lower().rfind(marker) + after_text = text[pos + len(marker):].strip() + after_text = re.sub(r"^[:\s]+", "", after_text) + return after_text.split('\n')[0].split('$')[0].strip(), True + + tail = text[-50:].strip() + nums = re.findall(r"(-?\d+[\./\d]*|\\sqrt\{\d+\}|\(-?\d+.*?\))", tail) + if nums: + return nums[-1], False + return "", False + +def normalize_math(string): + if not string: return "" + string = str(string).lower().strip() + + string = string.replace("", "").replace("", "").replace("", "") + string = string.replace("...", "").replace("cannot be determined", "") + + string = re.sub(r"([a-z]+|\\theta|\\alpha|\\pi)\s*=\s*", "", string) + string = re.sub(r"\\text\{([^}]*)\}", r"\1", string) + string = re.sub(r"\\(mathbf|mathrm|bold|unit|mbox|operatorname|mathrm)\{([^}]*)\}", r"\2", string) + string = re.sub(r"\\(d|t)?frac\{([^{}]*)\}\{([^{}]*)\}", r"\2/\3", string) + string = string.replace("\\!", "").replace("\\ ", "").replace("{", "").replace("}", "") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("\\$", "").replace("$", "").replace("\\%", "").replace("%", "") + + units_pattern = r"(units?|cm\^2|cm|inches|inch|square|degrees?|radians?|miles?|per|hour|cents?)" + string = re.sub(units_pattern, "", string) + string = string.replace("^{\\circ}", "").replace("^\\circ", "").replace("°", "").replace("\\degree", "") + string = string.replace("\\pi", "pi") + string = re.sub(r"(\d),(\d{3})", r"\1\2", string) + string = string.rstrip(".:,; ").replace(" ", "") + + if "=" in string: + string = string.split("=")[-1] + + return string + +def is_equiv(pred, gold): + if not pred: return False + p, g = normalize_math(pred), normalize_math(gold) + if p == g: return True + + if "=" in pred: + if normalize_math(pred.split("=")[-1]) == g: + return True + + try: + def to_float(s): + if '/' in s and s.count('/') == 1: + parts = s.split('/') + return float(parts[0]) / float(parts[1]) + if '_' in s: s = s.split('_')[0] + return float(s) + return math.isclose(to_float(p), to_float(g), rel_tol=1e-4) + except: + p_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", p) + g_fuzzy = re.sub(r"[^a-z0-9/,\-]", "", g) + return p_fuzzy == g_fuzzy if p_fuzzy else False + +def run_evaluation(target_path): + jsonl_files = [] + if os.path.isdir(target_path): + for root, dirs, files in os.walk(target_path): + for file in files: + if file.endswith(".jsonl") and not file.startswith("eval_voted_"): + jsonl_files.append(os.path.join(root, file)) + else: + jsonl_files = [target_path] + + for file_path in jsonl_files: + print(f">>> 正在评测: {file_path}") + detailed_results = [] + + voted_correct_count = 0 + pass_at_k_count = 0 + total_count = 0 + + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + try: + item = json.loads(line) + except: + continue + + doc = item.get("doc", {}) + ground_truth = str(item.get("target", doc.get("answer", ""))) + + current_nfe = item.get("nfe", 0) + nfe_list.append(current_nfe) + current_svf = item.get("svf_calls", 0) + svf_list.append(current_svf) + + ans_stats = {} + trajectories = item.get("all_trajectories", []) + + has_correct_trajectory = False + + for traj in trajectories: + raw_text = traj.get("resp", "") + score = traj.get("score", 0) + + extracted, _ = extract_answer(raw_text) + if not extracted: continue + + if is_equiv(extracted, ground_truth): + has_correct_trajectory = True + + norm = normalize_math(extracted) + if norm not in ans_stats: + ans_stats[norm] = { + "count": 0, + "max_score": -float('inf'), + "total_weight": 0.0, + "original": extracted + } + + ans_stats[norm]["count"] += 1 + if score > ans_stats[norm]["max_score"]: + ans_stats[norm]["max_score"] = score + + try: + weight = math.exp(score) + except OverflowError: + weight = float('inf') + ans_stats[norm]["total_weight"] += weight + + if has_correct_trajectory: + pass_at_k_count += 1 + + if not ans_stats: + best_pred = "" + else: + sorted_norms = sorted( + ans_stats.keys(), + key=lambda x: (ans_stats[x]["total_weight"], ans_stats[x]["max_score"], ans_stats[x]["count"]), + reverse=True + ) + best_norm = sorted_norms[0] + best_pred = ans_stats[best_norm]["original"] + + is_voted_correct = False + if best_pred and is_equiv(best_pred, ground_truth): + voted_correct_count += 1 + is_voted_correct = True + + total_count += 1 + + detailed_results.append({ + "question": doc.get("problem", "N/A"), + "final_voted_answer": best_pred, + "ground_truth": ground_truth, + "is_voted_correct": is_voted_correct, + "nfe": current_nfe, + "svf_calls": current_svf + }) + + pass_at_1_accuracy = (voted_correct_count / total_count * 100) if total_count > 0 else 0 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- Accuracy: {pass_at_1_accuracy:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + output_name = f"eval_voted_{os.path.basename(file_path).replace('.jsonl', '.json')}" + output_path = os.path.join(os.path.dirname(file_path), output_name) + + final_report = { + "summary": { + "accuracy": f"{pass_at_1_accuracy:.2f}%", + "correct_voted_count": voted_correct_count, + "total": total_count, + "nfe": avg_nfe, + "svf_calls": avg_svf + }, + "details": detailed_results + } + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/mbpp_all.py b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/mbpp_all.py new file mode 100644 index 0000000000000000000000000000000000000000..7dce200195e0e46d44c74008ade6d22492fc3267 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/metrics/mbpp_all.py @@ -0,0 +1,194 @@ +import os +import json +import ast +import glob +import re +import argparse +from typing import Dict, List, Optional, Set, Tuple +import evaluate as hf_evaluate + +RES_PATH = "" + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +def extract_python_code(text: str) -> str: + if not text: return "" + + text = text.replace("<|role_end|>", "").replace("<|endoftext|>", "").replace("<|notification_end|>", "") + + tag_matches = re.findall(r"(.*?)", text, re.DOTALL) + if tag_matches: + for block in tag_matches: + if "def " in block: + text = block + break + else: + text = tag_matches[0] + + if "```python" in text: + blocks = text.split("```python") + for b in blocks[1:]: + code = b.split("```")[0].strip() + if "def " in code: return code + elif "```" in text: + blocks = text.split("```") + for b in blocks[1:]: + code = b.strip() + if "def " in code: return code + + lines = text.split('\n') + cleaned_lines = [] + stop_words = ["Explanation:", "Example:", "Test Case:", "Output:", "Reasoning:"] + for line in lines: + if any(sw in line for sw in stop_words): break + cleaned_lines.append(line) + + return "\n".join(cleaned_lines).strip() + +def normalize_code_for_voting(code: str) -> str: + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (node.body and isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str)): + node.body.pop(0) + return ast.unparse(tree).strip() + except: + return re.sub(r"\s+", "", code) + +def run_evaluation(target_path): + target_path = os.path.abspath(target_path) + + if os.path.isdir(target_path): + search_pattern = os.path.join(target_path, "**/*.jsonl") + jsonl_files = glob.glob(search_pattern, recursive=True) + jsonl_files = [f for f in jsonl_files if not os.path.basename(f).startswith("eval_mbpp_")] + else: + jsonl_files = [target_path] + + if not jsonl_files: + print(f"Error: 在路径 {target_path} 及其子目录下未找到任何 .jsonl 文件。") + return + + try: + code_eval = hf_evaluate.load("code_eval") + except: + print("Error: Could not load code_eval. Ensure 'evaluate' and 'code_eval' are installed.") + return + + for file_path in jsonl_files: + print(f"\n>>> 正在评测 MBPP 文件: {file_path}") + all_candidate_predictions = [] + all_voted_predictions = [] + all_references = [] + detailed_results = [] + nfe_list = [] + svf_list = [] + + with open(file_path, 'r', encoding='utf-8') as f: + for line in f: + if not line.strip(): continue + item = json.loads(line) + + doc = item.get("doc", {}) + test_list = doc.get("test_list", []) + test_setup = doc.get("test_setup_code", "") + full_reference = (test_setup + "\n" + "\n".join(test_list)).strip() + + item_nfe = item.get("nfe", 0) + item_svf = item.get("svf_calls", 0) + nfe_list.append(item_nfe) + svf_list.append(item_svf) + + resps = item.get("resps", []) + trajs = item.get("all_trajectories", []) + + candidate_stats = {} + processed_candidates = [] + + source_data = trajs if trajs else resps + for idx, entry in enumerate(source_data): + raw_text = entry.get("resp", "") if isinstance(entry, dict) else (entry[0] if isinstance(entry, list) else entry) + score = entry.get("score", 0) if isinstance(entry, dict) else 0 + + code = extract_python_code(raw_text) + if not code: continue + + processed_candidates.append(code) + + try: + ast.parse(code) + is_valid = True + except: + is_valid = False + + norm = normalize_code_for_voting(code) + if norm not in candidate_stats: + candidate_stats[norm] = {"count": 0, "valid": is_valid, "code": code, "max_score": -float('inf')} + candidate_stats[norm]["count"] += 1 + candidate_stats[norm]["max_score"] = max(candidate_stats[norm]["max_score"], score) + + if not candidate_stats: + voted_code = "" + else: + sorted_norms = sorted( + candidate_stats.keys(), + key=lambda k: (candidate_stats[k]["valid"], candidate_stats[k]["max_score"], candidate_stats[k]["count"]), + reverse=True + ) + voted_code = candidate_stats[sorted_norms[0]]["code"] + + all_candidate_predictions.append(processed_candidates if processed_candidates else [""]) + all_voted_predictions.append([voted_code]) + all_references.append(full_reference) + + detailed_results.append({ + "task_id": doc.get("task_id", "N/A"), + "voted_code": voted_code, + "nfe": item_nfe, + "svf_calls": item_svf, + "candidates_count": len(processed_candidates) + }) + + if not all_voted_predictions: + continue + + print(f"正在测试代码 (共 {len(all_voted_predictions)} 题)...") + res_voted, details_voted = code_eval.compute(references=all_references, predictions=all_voted_predictions, k=[1]) + res_pk, details_pk = code_eval.compute(references=all_references, predictions=all_candidate_predictions, k=[1]) + + acc_voted = res_voted.get("pass@1", 0.0) * 100 + acc_pk = res_pk.get("pass@1", 0.0) * 100 + avg_nfe = int(round(sum(nfe_list) / len(nfe_list))) if nfe_list else 0 + avg_svf = int(round(sum(svf_list) / len(svf_list))) if svf_list else 0 + + print(f"--- Pass@1: {acc_voted:.2f}% | NFE: {avg_nfe} | SVF: {avg_svf} ---") + + for i, detail in enumerate(detailed_results): + detail["is_voted_correct"] = details_voted.get(i, [[0, {"passed": False}]])[0][1]["passed"] + + file_dir = os.path.dirname(file_path) + base_name = os.path.basename(file_path) + output_name = f"eval_mbpp_{base_name.replace('.jsonl', '.json')}" + output_path = os.path.join(file_dir, output_name) + + final_report = { + "summary": { + "pass_at_1": f"{acc_voted:.2f}%", + "avg_nfe": avg_nfe, + "avg_svf": avg_svf + }, + "details": detailed_results + } + + with open(output_path, 'w', encoding='utf-8') as out_f: + json.dump(final_report, out_f, ensure_ascii=False, indent=4) + print(f"成功保存结果至: {output_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--res_path", type=str, default=RES_PATH) + args = parser.parse_args() + run_evaluation(args.res_path) \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/requirements.txt b/Prism/LLaDA2mini/LLaDA2mini_Prism/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ae6174b97bc14baecfc1f884ce4881f62558633 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/requirements.txt @@ -0,0 +1,9 @@ +sacrebleu +evaluate +datasets +numpy +pandas +tqdm +regex +sqlitedict +pytablewriter \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_gsm8k.sh b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_gsm8k.sh new file mode 100644 index 0000000000000000000000000000000000000000..02f138b2175e6e4e06a8996901db87df858b5a08 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_gsm8k.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/llada2_gsm8k" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=256 +STEPS=32 +BLOCK=32 +TASK="gsm8k" +TYPE="math" +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=${TYPE},steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_humaneval.sh b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_humaneval.sh new file mode 100644 index 0000000000000000000000000000000000000000..b14ef1e83f6beaed062fb043d61159022833e847 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_humaneval.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/llada2_humaneval" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=512 +STEPS=32 +BLOCK=32 +TASK="humaneval" +TYPE="code" +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=${TYPE},steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_math500.sh b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_math500.sh new file mode 100644 index 0000000000000000000000000000000000000000..92fce436316d8ef1f742c94d4a51a2f2fd2b96a6 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_math500.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/llada2_math500" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=256 +STEPS=32 +BLOCK=32 +TASK="math500" +TYPE="math" +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=${TYPE},steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_mbpp.sh b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_mbpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..af32536440f3f98684bb4b7572a54974d47b9922 --- /dev/null +++ b/Prism/LLaDA2mini/LLaDA2mini_Prism/scripts/run_mbpp.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -e +set -x + +PROJECT_ROOT="" +MODEL_PATH="" +BASE_OUTPUT_PATH="${PROJECT_ROOT}/outputs/llada2_mbpp" + +cd "$PROJECT_ROOT" +export CUDA_VISIBLE_DEVICES=0 +export HF_ENDPOINT=https://hf-mirror.com + +LENGTH=512 +STEPS=32 +BLOCK=32 +TASK="mbpp" +TYPE="code" +NAME="win_0.1-0.6_s2_k4" + +mkdir -p "${BASE_OUTPUT_PATH}/${NAME}" + +accelerate launch evaluation_script.py \ + --model LLaDA2 \ + --tasks ${TASK} \ + --batch_size 1 \ + --model_args "pretrained=${MODEL_PATH},assistant_prefix= " \ + --gen_kwargs "use_hts=True,hts_N=16,final_K=4,hts_survivor_k=2,hts_mode=True,hts_start_pct=0.1,hts_end_pct=0.6,pruning_interval=3,decay_factor=1.8,reward_mode=svf,task_type=${TYPE},steps=${STEPS},block_length=${BLOCK},gen_length=${LENGTH},temperature=0.7,realtime_output=${BASE_OUTPUT_PATH}/${NAME}/res.jsonl" \ + --num_fewshot 0 \ + --confirm_run_unsafe_code \ + --output_path "${BASE_OUTPUT_PATH}/${NAME}" \ No newline at end of file diff --git a/URSA-1.7B/scheduler/__scheduler__.py b/URSA-1.7B/scheduler/__scheduler__.py new file mode 100644 index 0000000000000000000000000000000000000000..c80a5b63b8cfe1030ec7eaaf3b64f96882ba6b50 --- /dev/null +++ b/URSA-1.7B/scheduler/__scheduler__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""Scheduler.""" + +from diffnext.schedulers.scheduling_dfm import KineticOptimalScheduler # noqa diff --git a/URSA-1.7B/scheduler/scheduler_config.json b/URSA-1.7B/scheduler/scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..4f25dfeabc61bd4fd4618a386d90a3287acff2ed --- /dev/null +++ b/URSA-1.7B/scheduler/scheduler_config.json @@ -0,0 +1,7 @@ +{ + "_class_name": "KineticOptimalScheduler", + "alpha": 1.0, + "c": 5, + "eps": 1e-5, + "shift": 4.0 +} diff --git a/URSA-1.7B/tokenizer/tokenizer_config.json b/URSA-1.7B/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..417d038a63fa3de29cfde265caedae14d1a58d92 --- /dev/null +++ b/URSA-1.7B/tokenizer/tokenizer_config.json @@ -0,0 +1,239 @@ +{ + "add_bos_token": false, + "add_prefix_space": false, + "added_tokens_decoder": { + "151643": { + "content": "<|endoftext|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151644": { + "content": "<|im_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151645": { + "content": "<|im_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151646": { + "content": "<|object_ref_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151647": { + "content": "<|object_ref_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151648": { + "content": "<|box_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151649": { + "content": "<|box_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151650": { + "content": "<|quad_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151651": { + "content": "<|quad_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151652": { + "content": "<|vision_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151653": { + "content": "<|vision_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151654": { + "content": "<|vision_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151655": { + "content": "<|image_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151656": { + "content": "<|video_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "151657": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151658": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151659": { + "content": "<|fim_prefix|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151660": { + "content": "<|fim_middle|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151661": { + "content": "<|fim_suffix|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151662": { + "content": "<|fim_pad|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151663": { + "content": "<|repo_name|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151664": { + "content": "<|file_sep|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151665": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151666": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151667": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + }, + "151668": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": false + } + }, + "additional_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|object_ref_start|>", + "<|object_ref_end|>", + "<|box_start|>", + "<|box_end|>", + "<|quad_start|>", + "<|quad_end|>", + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + "<|image_pad|>", + "<|video_pad|>" + ], + "bos_token": null, + "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}", + "clean_up_tokenization_spaces": false, + "eos_token": "<|im_end|>", + "errors": "replace", + "model_max_length": 131072, + "pad_token": "<|endoftext|>", + "split_special_tokens": false, + "tokenizer_class": "Qwen2Tokenizer", + "unk_token": null +} diff --git a/URSA-1.7B/transformer/__transformer__.py b/URSA-1.7B/transformer/__transformer__.py new file mode 100644 index 0000000000000000000000000000000000000000..fac56e3856b5bf914da4fe8a367a86c8b77b4fb4 --- /dev/null +++ b/URSA-1.7B/transformer/__transformer__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""Transformer model.""" + +from diffnext.models.transformers.transformer_ursa import URSATransformer3DModel # noqa diff --git a/URSA-1.7B/transformer/config.json b/URSA-1.7B/transformer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6f97d0f4dbb10e585a5483bfaa105f57d9f8acdf --- /dev/null +++ b/URSA-1.7B/transformer/config.json @@ -0,0 +1,13 @@ +{ + "hidden_size": 2048, + "intermediate_size": 6144, + "max_window_layers": 28, + "num_attention_heads": 16, + "num_key_value_heads": 8, + "num_hidden_layers": 28, + "rope_theta": 1000000, + "vocab_size": 215669, + "lm_vocab_size": 151669, + "lm_head_size": 64000, + "bov_token_id": 151652 +} diff --git a/URSA-1.7B/vae/__vae__.py b/URSA-1.7B/vae/__vae__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab040888aae6960f5e04480664f3900b799614fd --- /dev/null +++ b/URSA-1.7B/vae/__vae__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################## +"""VAE model.""" + +from diffnext.models.autoencoders.autoencoder_vq_cosmos3d import AutoencoderVQCosmos3D # noqa diff --git a/URSA-1.7B/vae/config.json b/URSA-1.7B/vae/config.json new file mode 100644 index 0000000000000000000000000000000000000000..6700a83e4791744c41e6c22b647b73ce70e16c37 --- /dev/null +++ b/URSA-1.7B/vae/config.json @@ -0,0 +1,22 @@ +{ + "_class_name": "AutoencoderVQCosmos3D", + "_quantizer_name": "FSQuantizer", + "in_channels": 3, + "latent_channels": 256, + "layers_per_block": 2, + "norm_num_groups": 1, + "out_channels": 3, + "sample_size": 1024, + "sample_frames": 49, + "num_vq_embeddings": 64000, + "vq_embed_dim": 6, + "patch_size": 2, + "temporal_stride": 4, + "spatial_stride": 8, + "block_out_channels": [ + 128, + 256, + 512, + 512 + ] +} diff --git a/URSA/.venv_ursa/pyvenv.cfg b/URSA/.venv_ursa/pyvenv.cfg new file mode 100644 index 0000000000000000000000000000000000000000..4a3f8bf6140eacb99c919186e592438ac1bf2fa7 --- /dev/null +++ b/URSA/.venv_ursa/pyvenv.cfg @@ -0,0 +1,5 @@ +home = /usr/bin +include-system-site-packages = false +version = 3.12.3 +executable = /usr/bin/python3.12 +command = /gfs/space/private/fengzl/llm_service_test/GLM-5/.venv/bin/python3.12 -m venv /gfs/space/private/fengzl/world_model/URSA/.venv_ursa diff --git a/URSA/accelerate_configs/deepspeed_zero2.yaml b/URSA/accelerate_configs/deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5ae955d0b868dce3f54519f7c38c09a639221d4d --- /dev/null +++ b/URSA/accelerate_configs/deepspeed_zero2.yaml @@ -0,0 +1,12 @@ +distributed_type: DEEPSPEED +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_clipping: 0.0 + zero_stage: 3 #2 + offload_optimizer_device: cpu # Moves optimizer states to CPU RAM + offload_param_device: cpu # Moves model parameters to CPU RAM + zero3_init_flag: true # Initializes the model directly across GPUs to save CPU RAM + zero3_save_16bit_model: true # Consolidates weights into a single file when saving checkpoints +num_machines: 1 +num_processes: 8 +machine_rank: 0 diff --git a/URSA/assets/sample_image.jpg b/URSA/assets/sample_image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1200b735417a797eaa3df888273540d97fd83b0b Binary files /dev/null and b/URSA/assets/sample_image.jpg differ diff --git a/URSA/configs/distill_dimo.yaml b/URSA/configs/distill_dimo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc314d3833a6f35c64b88224971ac9ec92cb47d8 --- /dev/null +++ b/URSA/configs/distill_dimo.yaml @@ -0,0 +1,158 @@ +# ============================================================================ +# URSA one-step distillation — DiMO-style distributed training config +# ============================================================================ +# Verified native inference regime (from A/B testing — ground truth): +# height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50. +# no_cfg (guidance_scale=1) does NOT produce valid output. +# All defaults below align to this verified regime. +# +# Launch (8-GPU, single node): +# +# accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \ +# --machine_rank 0 --num_machines 1 --num_processes 8 \ +# scripts/train_distill_dimo.py \ +# config="./configs/distill_dimo.yaml" \ +# experiment.output_dir="./experiments/distill_dimo" \ +# distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \ +# distill.prompt_source="/data/Koala_36M_*.csv" +# +# Smoke test (1 GPU, 50 steps — save student checkpoint): +# +# accelerate launch --num_processes 1 \ +# scripts/train_distill_dimo.py \ +# config="./configs/distill_dimo.yaml" \ +# experiment.output_dir="./experiments/smoke" \ +# distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \ +# distill.prompt_source="prompts.txt" \ +# training.max_train_steps=50 \ +# experiment.save_every=50 +# +# Load student for 1-step inference (must use CFG=7, native geometry): +# +# pipe = URSAPipeline.from_pretrained("/path/to/URSA-1.7B-IBQ1024") +# state = torch.load("experiments/distill_dimo/checkpoints/final/student.pt") +# pipe.transformer.load_state_dict(state, strict=True) +# frames = pipe(prompt="...", num_inference_steps=1, +# height=320, width=512, num_frames=49, +# guidance_scale=7).frames +# ============================================================================ + +# ── Experiment bookkeeping ─────────────────────────────────────────────────── +experiment: + name: distill_dimo + output_dir: ./experiments/distill_dimo + log_every: 10 + save_every: 100 + resume_iter: 0 # set to step number to resume + +# ── Training (framework-level) ─────────────────────────────────────────────── +training: + seed: 42 + mixed_precision: bf16 # bf16 | fp16 | fp32 + max_train_steps: 10000 + gradient_accumulation_steps: 1 # Two-backward; keep =1 for distillation + +# ── Distillation hyperparameters ───────────────────────────────────────────── +distill: + # ---- Paths ---------------------------------------------------------------- + teacher_ckpt: /gfs/space/private/fengzl/World_Model/URSA-1.7B + prompt_source: /gfs/space/private/fengzl/World_Model/Koala-36M-v1 # glob, dir, .txt, or comma-list + + # ---- Video geometry (verified native: 320×512×49) ------------------------- + num_frames: 49 + height: 320 + width: 512 + max_prompt_length: 320 + + # ---- Data ----------------------------------------------------------------- + batch_size_per_gpu: 1 # effective global batch = batch_size_per_gpu × 8 GPUs + + # # ---- Loss weights --------------------------------------------------------- + # lambda_kd: 0.5 # KL(z_T || z_S) weight + # lambda_pg: 1.0 # REINFORCE policy gradient weight + # lambda_ent: 0.01 # entropy bonus (λ_ent_eff × H) — set 0 for DiMO orig + # tau: 1.0 # student sampling temperature + # tau_kd: 1.0 # KD / Jeffrey softmax temperature + + # # ---- Teacher CFG (aligned to verified working regime: CFG=7) --------------- + # # A/B testing confirmed: guidance_scale=1 (no_cfg) does NOT produce valid + # # output for this URSA checkpoint. The teacher KD target must use CFG=7. + # enable_teacher_cfg: true + # teacher_cfg_scale: 7.0 # s in z_guided = z_uncond + s*(z_cond-z_uncond) + # # Verified: CFG=7 is the official working value. + # teacher_cfg_prob: 1.0 # max fraction of samples using guided target + # teacher_cfg_warmup_steps: 2000 # linear warmup 0→teacher_cfg_prob + # teacher_cfg_trunc: 0.9 # when t≥trunc, scale falls to 1 (no guide) + # lambda_kd_uncond: 0.3 # weight for uncond-branch KD loss + # reward_use_guided: false # [RISKY] use guided logits for reward signal + + # # ---- DiMO extensions ------------------------------------------------------- + # fake_rounds: 1 # aux updates per student update (DiMO=2; try 2) + # use_surrogate_grad: false + # lambda_surr: 1.0 + + # ---- Loss weights --------------------------------------------------------- + lambda_kd: 1.0 # KL(z_T || z_S) weight (基础知识蒸馏权重,保持不变) + lambda_pg: 1.0 # [重用] 现在代表 lambda_bridge,控制 MSE 伪梯度注入的强度 + lambda_ent: 0.0 # [已废弃] 强化学习的熵奖励已彻底删除,设为 0.0 + tau: 1.0 # student sampling temperature + tau_kd: 1.0 # KD softmax temperature + + # ---- Teacher CFG (aligned to verified working regime: CFG=7) --------------- + enable_teacher_cfg: true + teacher_cfg_scale: 7.0 + teacher_cfg_prob: 1.0 + teacher_cfg_warmup_steps: 1000 + teacher_cfg_trunc: 0.9 + lambda_kd_uncond: 0.3 + # reward_use_guided: false <-- [请直接删除这行] 因为 Reward 计算已被移除 + + # ---- DiMO extensions ------------------------------------------------------- + fake_rounds: 2 #1 # Aux 拟合假 token 的迭代次数。如果发现 Aux 算出的 bridge_loss 降不下去,可以尝试改为 2 + use_surrogate_grad: false + lambda_surr: 1.0 + + # ---- Stability ------------------------------------------------------------- + t_curriculum_steps: 10000 # curriculum steps before uniform-t sampling + p_init_mix_ratio: 0.2 # fraction of batch from corrupted x_hat_prev + p_mix_corrupt_frac: 0.2 # token corruption rate in p_init mixing + collapse_warn_frac: 0.2 # warn if tok_entropy < frac × initial entropy + + # ---- Aux initialisation --------------------------------------------------- + aux_noise_std: 1.0e-5 # tiny noise added to aux weights at init to break + # symmetry; set 0.0 to keep aux == student exactly + + # ---- Gradient clipping ---------------------------------------------------- + grad_clip: 1.0 + +# ── Student optimizer ──────────────────────────────────────────────────────── +optimizer_student: + target: torch.optim.AdamW + params: + lr: 1.0e-5 + betas: [0.9, 0.95] + weight_decay: 0.01 + +# ── Aux optimizer ──────────────────────────────────────────────────────────── +optimizer_aux: + target: torch.optim.AdamW + params: + lr: 1.0e-5 + betas: [0.9, 0.95] + weight_decay: 0.01 + +# ── LR scheduler (cosine, shared warmup/decay params for both opts) ────────── +lr_scheduler: + target: diffnext.engine.lr_scheduler.CosineLR + params: + lr_max: ${optimizer_student.params.lr} + lr_min: 1.0e-6 + max_steps: ${training.max_train_steps} + warmup_steps: 500 + +# ── Prompt DataLoader ───────────────────────────────────────────────────────── +prompt_dataloader: + shuffle_files: true + shuffle_buffer: 50000 # in-memory shuffle buffer per shard; reduce if OOM + num_workers: 4 # CPU workers (no CUDA in workers) + caption_field: caption # CSV column name (Koala default) diff --git a/URSA/configs/onestep_dimo.yaml b/URSA/configs/onestep_dimo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..83e917530b7b0417fc4eb9a43025dd79b8818440 --- /dev/null +++ b/URSA/configs/onestep_dimo.yaml @@ -0,0 +1,111 @@ +# ============================================================================ +# URSA one-step distillation — DiMO-style training configuration +# ============================================================================ +# Reference: train_onestep_ursa_dimo.py +# +# DiMO hyperparameter comparison (Meissonic vs. our URSA defaults) +# --------------------------------------------------------------- +# Param DiMO (Meissonic) URSA (this config) Risk / Note +# ───────────────────────────────────────────────────────────────────────── +# guidance_scale (CFG) 3.0 (true_cfg) 3.0 (teacher_cfg) ✅ aligned +# fake_rounds 2 1 ⚠ try 2 for aux stability +# fixed_ratio 0.5 (mask ratio) — N/A (different domain) +# distil_loss_type surrogate MSE optional surrogate ✅ toggle via use_surrogate_grad +# noise_emb_perturb True — ℹ️ not needed for VQ-based model +# cfg_prob 1.0 teacher_cfg_prob=1.0 ✅ aligned +# lambda_ent 0.0 (no ent reg) 0.01 ℹ️ our addition for stability +# ============================================================================ + +# ── Paths ──────────────────────────────────────────────────────────────────── +teacher_ckpt: "/path/to/URSA" +prompt_file: "prompts.txt" +out_dir: "./outputs/dimo" + +# ── Video geometry ─────────────────────────────────────────────────────────── +num_frames: 17 +height: 256 +width: 256 +max_prompt_length: 320 + +# ── Training ───────────────────────────────────────────────────────────────── +batch_size: 2 # reduce to 1 if enable_teacher_cfg uses too much VRAM +num_steps: 10000 +lr_student: 1.0e-5 +lr_aux: 1.0e-5 +weight_decay: 0.01 +grad_clip: 1.0 +mixed_precision: "bf16" +seed: 42 +log_every: 50 +save_every: 1000 + +# ── Loss weights ───────────────────────────────────────────────────────────── +lambda_pg: 1.0 +lambda_kd: 0.5 +lambda_ent: 0.01 # entropy regularisation (0 → DiMO original; 0.01 → our default) +tau: 1.0 # student sampling temperature +tau_kd: 1.0 # KD softmax temperature + +# ── Teacher CFG (DiMO true_cfg style) ──────────────────────────────────────── +# Set enable_teacher_cfg: false to revert to the prior single-branch behavior. +# All other params in this block are ignored when enable_teacher_cfg=false. +enable_teacher_cfg: true + +teacher_cfg_scale: 3.0 # s in z_guided = z_uncond + s*(z_cond - z_uncond) + # Matches DiMO true_cfg=3.0 + +teacher_cfg_prob: 1.0 # Probability of using guided target per batch (after warmup). + # 1.0 = always guided (DiMO default). + +teacher_cfg_warmup_steps: 2000 + # Ramp teacher_cfg_prob from 0 → teacher_cfg_prob over this many + # steps. Prevents instability at the start of training. + +teacher_cfg_trunc: 0.9 # When t >= trunc, CFG scale falls to 1 (no guidance at high noise). + # Mirrors DiMO's guidance_trunc parameter. + +lambda_kd_uncond: 0.3 # Weight for uncond-branch KD loss. + # Keeps the student uncond-capable for eval-time CFG. + +reward_use_guided: false # [RISKY] Use guided teacher logits for REINFORCE reward. + # Default false: use non-guided cond (more stable). + +# ── Eval / inference CFG ───────────────────────────────────────────────────── +eval_cfg_scale: 3.0 # guidance_scale used during evaluation +use_cfg_eval: false # Run eval with inference-time CFG (2× forward) + +# ── DiMO extensions ────────────────────────────────────────────────────────── +use_surrogate_grad: false # DiMO surrogate MSE trick (zero-variance alternative to REINFORCE) +lambda_surr: 1.0 +fake_rounds: 1 # Aux updates per generator update (DiMO uses 2; try 2 for aux stability) + +# ── Stability ───────────────────────────────────────────────────────────────── +t_curriculum_steps: 10000 # Steps to use t-curriculum (biases t toward larger values) +p_mix_corrupt_frac: 0.2 # Fraction of tokens to corrupt in p_init mixing +p_init_mix_ratio: 0.2 # Fraction of batch drawn from corrupted x_hat_prev +collapse_warn_frac: 0.2 # Warn if tok_hist_entropy drops below this fraction of initial + +# ── Debug ──────────────────────────────────────────────────────────────────── +dry_run: false # Run 1 step, print diagnostics, exit +debug_dump: 0 # Dump token histogram + x_hat every N steps (0=off) + +# ── Recommended quick-start commands ───────────────────────────────────────── +# # Smoke test (CFG enabled): +# python scripts/train_onestep_ursa_dimo.py \ +# --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \ +# --enable_teacher_cfg --teacher_cfg_scale 3.0 \ +# --num_frames 17 --height 256 --width 256 --dry_run +# +# # Full training (DiMO-aligned): +# python scripts/train_onestep_ursa_dimo.py \ +# --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \ +# --enable_teacher_cfg --teacher_cfg_scale 3.0 \ +# --batch_size 2 --num_steps 10000 --fake_rounds 2 \ +# --out_dir ./outputs/dimo_cfg +# +# # Eval (compare 3 student modes vs teacher): +# python scripts/eval_onestep_ursa.py \ +# --teacher_ckpt /path/to/URSA \ +# --student_ckpt ./outputs/dimo_cfg/final/student.pt \ +# --modes no_cfg cfg baked --eval_cfg_scale 3.0 \ +# --out_dir ./outputs/eval diff --git a/URSA/configs/ursa_0.6b_fsq320.yaml b/URSA/configs/ursa_0.6b_fsq320.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b890e05c5496dca4cc7aff614b9452e5f540f8dd --- /dev/null +++ b/URSA/configs/ursa_0.6b_fsq320.yaml @@ -0,0 +1,62 @@ +wandb: + run_id: null + +experiment: + project: ursa_0.6b_fsq320 + log_every: 20 + save_every: 5000 + resume_from_checkpoint: latest + +model: + name: "transformer" + gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt + async_timestep: true + tokenizer: + params: + max_length: 320 + truncation: true + padding_side: left + padding: max_length + +pipeline: + target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline + paths: + pretrained_path: /path/to/URSA-0.6B-FSQ320 + module_dict: + vae: ${pipeline.paths.pretrained_path}/vae + scheduler: ${pipeline.paths.pretrained_path}/scheduler + tokenizer: ${pipeline.paths.pretrained_path}/tokenizer + model_index: ${pipeline.paths.pretrained_path}/model_index.json + +optimizer: + target: torch.optim.AdamW + param_groups: false + params: + lr: 0.00003 + betas: [0.9, 0.95] + weight_decay: 0.05 + fused: true + +lr_scheduler: + target: diffnext.engine.lr_scheduler.CosineLR + params: + lr_max: ${optimizer.params.lr} + lr_min: 0.00001 + max_steps: ${training.max_train_steps} + warmup_steps: 500 + +train_dataloader: + target: diffnext.data.flex_loaders.FeatureDataLoader + params: + dataset: /path/to/fsq320_dataset + batch_size: ${training.batch_size} + seed: ${training.seed} + num_workers: 4 + shuffle: true + +training: + gradient_accumulation_steps: 1 + batch_size: 1 # * 256 = 256 + max_train_steps: 20000 + seed: 1337 + mixed_precision: bf16 diff --git a/URSA/configs/ursa_0.6b_ibq1024.yaml b/URSA/configs/ursa_0.6b_ibq1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..352cb87b27ec6164a2c1d714be10068134ae0b37 --- /dev/null +++ b/URSA/configs/ursa_0.6b_ibq1024.yaml @@ -0,0 +1,62 @@ +wandb: + run_id: null + +experiment: + project: ursa_0.6b_ibq1024 + log_every: 20 + save_every: 5000 + resume_from_checkpoint: latest + +model: + name: "transformer" + gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt + async_timestep: false + tokenizer: + params: + max_length: 320 + truncation: true + padding_side: left + padding: max_length + +pipeline: + target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline + paths: + pretrained_path: /path/to/URSA-0.6B-IBQ1024 + module_dict: + vae: ${pipeline.paths.pretrained_path}/vae + scheduler: ${pipeline.paths.pretrained_path}/scheduler + tokenizer: ${pipeline.paths.pretrained_path}/tokenizer + model_index: ${pipeline.paths.pretrained_path}/model_index.json + +optimizer: + target: torch.optim.AdamW + param_groups: false + params: + lr: 0.00003 + betas: [0.9, 0.95] + weight_decay: 0.05 + fused: true + +lr_scheduler: + target: diffnext.engine.lr_scheduler.CosineLR + params: + lr_max: ${optimizer.params.lr} + lr_min: 0.00001 + max_steps: ${training.max_train_steps} + warmup_steps: 500 + +train_dataloader: + target: diffnext.data.flex_loaders.FeatureDataLoader + params: + dataset: /path/to/ibq1024_dataset + batch_size: ${training.batch_size} + seed: ${training.seed} + num_workers: 4 + shuffle: true + +training: + gradient_accumulation_steps: 1 + batch_size: 1 # * 512 = 512 + max_train_steps: 120000 + seed: 1337 + mixed_precision: bf16 diff --git a/URSA/configs/ursa_1.7b_fsq320.yaml b/URSA/configs/ursa_1.7b_fsq320.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2e5f945f904913b87b1809aff61e8beb38803ea --- /dev/null +++ b/URSA/configs/ursa_1.7b_fsq320.yaml @@ -0,0 +1,62 @@ +wandb: + run_id: null + +experiment: + project: ursa_1.7b_fsq320 + log_every: 20 + save_every: 5000 + resume_from_checkpoint: latest + +model: + name: "transformer" + gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt + async_timestep: true + tokenizer: + params: + max_length: 320 + truncation: true + padding_side: left + padding: max_length + +pipeline: + target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline + paths: + pretrained_path: /path/to/URSA-1.7B-FSQ320 + module_dict: + vae: ${pipeline.paths.pretrained_path}/vae + scheduler: ${pipeline.paths.pretrained_path}/scheduler + tokenizer: ${pipeline.paths.pretrained_path}/tokenizer + model_index: ${pipeline.paths.pretrained_path}/model_index.json + +optimizer: + target: torch.optim.AdamW + param_groups: false + params: + lr: 0.00003 + betas: [0.9, 0.95] + weight_decay: 0.05 + fused: true + +lr_scheduler: + target: diffnext.engine.lr_scheduler.CosineLR + params: + lr_max: ${optimizer.params.lr} + lr_min: 0.00001 + max_steps: ${training.max_train_steps} + warmup_steps: 500 + +train_dataloader: + target: diffnext.data.flex_loaders.FeatureDataLoader + params: + dataset: /path/to/fsq320_dataset + batch_size: ${training.batch_size} + seed: ${training.seed} + num_workers: 4 + shuffle: true + +training: + gradient_accumulation_steps: 1 + batch_size: 1 # * 256 = 256 + max_train_steps: 20000 + seed: 1337 + mixed_precision: bf16 diff --git a/URSA/configs/ursa_1.7b_ibq1024.yaml b/URSA/configs/ursa_1.7b_ibq1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d2e122599f087447f137c67fbe799ec0a4efb0d --- /dev/null +++ b/URSA/configs/ursa_1.7b_ibq1024.yaml @@ -0,0 +1,62 @@ +wandb: + run_id: null + +experiment: + project: ursa_1.7b_ibq1024 + log_every: 20 + save_every: 5000 + resume_from_checkpoint: latest + +model: + name: "transformer" + gradient_checkpointing: 2 # 1: +mlp_ckpt 2: +qkv_ckpt 3: +layer_ckpt + async_timestep: false + tokenizer: + params: + max_length: 320 + truncation: true + padding_side: left + padding: max_length + +pipeline: + target: diffnext.pipelines.ursa.pipeline_train.URSATrainPipeline + paths: + pretrained_path: /path/to/URSA-1.7B-IBQ1024 + module_dict: + vae: ${pipeline.paths.pretrained_path}/vae + scheduler: ${pipeline.paths.pretrained_path}/scheduler + tokenizer: ${pipeline.paths.pretrained_path}/tokenizer + model_index: ${pipeline.paths.pretrained_path}/model_index.json + +optimizer: + target: torch.optim.AdamW + param_groups: false + params: + lr: 0.00003 + betas: [0.9, 0.95] + weight_decay: 0.05 + fused: true + +lr_scheduler: + target: diffnext.engine.lr_scheduler.CosineLR + params: + lr_max: ${optimizer.params.lr} + lr_min: 0.00001 + max_steps: ${training.max_train_steps} + warmup_steps: 500 + +train_dataloader: + target: diffnext.data.flex_loaders.FeatureDataLoader + params: + dataset: /path/to/ibq1024_dataset + batch_size: ${training.batch_size} + seed: ${training.seed} + num_workers: 4 + shuffle: true + +training: + gradient_accumulation_steps: 1 + batch_size: 1 # * 512 = 512 + max_train_steps: 120000 + seed: 1337 + mixed_precision: bf16 diff --git a/URSA/diffnext/__init__.py b/URSA/diffnext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4111073bdbc3d91c74c4a15e95caab0c862ac142 --- /dev/null +++ b/URSA/diffnext/__init__.py @@ -0,0 +1,16 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------ +"""DiffNext: A diffusers based library for autoregressive diffusion models.""" diff --git a/URSA/diffnext/image_processor.py b/URSA/diffnext/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..2420e79d74dd34897b68ce703770b38699f34982 --- /dev/null +++ b/URSA/diffnext/image_processor.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------ +"""Image processor.""" + +from typing import List, Union + +import numpy as np +import PIL.Image +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin + + +class VaeImageProcessor(ConfigMixin): + """Image processor for VAE.""" + + def postprocess( + self, image: torch.Tensor, output_type: str = "pil" + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: + """Postprocess the image output from tensor. + + Args: + image (torch.Tensor): + The image tensor. + output_type (str, *optional*, defaults to `pil`): + The output image type, can be one of `pil`, `np`, `pt`, `latent`. + + Returns: + Union[PIL.Image.Image, np.ndarray, torch.Tensor]: The postprocessed image. + """ + if output_type == "latent" or output_type == "pt": + return image + image = self.pt_to_numpy(image) + if output_type == "np": + return image + if output_type == "pil": + return self.numpy_to_pil(image) + return image + + @staticmethod + @torch.no_grad() + def decode_latents(vae: nn.Module, latents: torch.Tensor, vae_batch_size=1) -> torch.Tensor: + """Decode VAE latents. + + Args: + vae (torch.nn.Module): + The VAE model. + latents (torch.Tensor): + The input latents. + vae_batch_size (int, *optional*, defaults to 1) + The maximum images in a batch to decode. + + Returns: + torch.Tensor: The output tensor. + + """ + x, batch_size = vae.unscale_(latents), latents.size(0) + sizes, splits = [vae_batch_size] * (batch_size // vae_batch_size), [] + sizes += [batch_size - sum(sizes)] if sum(sizes) != batch_size else [] + for x_split in x.split(sizes) if len(sizes) > 1 else [x]: + splits.append(vae.decode(x_split).sample) + return torch.cat(splits) if len(splits) > 1 else splits[0] + + @staticmethod + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + """Convert images from a torch tensor to a numpy array. + + Args: + images (torch.Tensor): + The image tensor. + + Returns: + np.ndarry: The image array. + """ + x = images.permute(0, 2, 3, 4, 1) if images.dim() == 5 else images.permute(0, 2, 3, 1) + return x.mul(127.5).add_(127.5).clamp(0, 255).byte().cpu().numpy() + + @staticmethod + def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: + """Convert images from a numpy array to a list of PIL objects. + + Args: + images (np.ndarray): + The image array. + + Returns: + List[PIL.Image.Image]: A list of PIL images. + """ + images = images[None, ...] if images.ndim == 3 else images + images = images.reshape((-1,) + images.shape[2:]) if images.ndim == 5 else images + return [PIL.Image.fromarray(image) for image in images] diff --git a/URSA/diffnext/version.py b/URSA/diffnext/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a15a8dc6128103bf9f7e6896519ee17284c41a52 --- /dev/null +++ b/URSA/diffnext/version.py @@ -0,0 +1,3 @@ +version = "0.3.0a0" +git_version = "27f3f7577bcc71c2f08a8a069e4ef4ab70cf8bd7" +__version__ = version diff --git a/URSA/docs/evaluation.md b/URSA/docs/evaluation.md new file mode 100644 index 0000000000000000000000000000000000000000..aa8e716bc0a85239a16e73bab3e487ff94a35721 --- /dev/null +++ b/URSA/docs/evaluation.md @@ -0,0 +1,49 @@ +# Evaluations + +## GenEval + +### 1. Sample prompt images +```bash +python ./evaluations/geneval/sample.py \ +--height 1024 --width 1024 \ +--guidance_scale 7 --num_inference_steps 25 \ +--ckpt /path/to/URSA-1.7B-IBQ1024 \ +--prompt_size 4 --outdir ./samples/geneval/URSA-1.7B-IBQ1024 +``` + +### 2. Evaluation +=./samples/geneval/URSA-1.7B-IBQ1024 + +Please refer [GenEval](https://github.com/djghosh13/geneval?tab=readme-ov-file#evaluation) evaluation guide. + +## DPG-Bench + +### 1. Sample prompt images +```bash +python evaluations/dpgbench/sample.py \ +--height 1024 --width 1024 \ +--guidance_scale 7 --num_inference_steps 25 \ +--ckpt ./checkpoints/URSA-1.7B-IBQ1024 \ +--prompt_size 4 --outdir samples/dpgbench/URSA-1.7B-IBQ1024 +``` + +### 2. Evaluation +=./samples/dpgbench/URSA-1.7B-IBQ1024 + +Please refer [DPG-Bench](https://github.com/TencentQQGYLab/ELLA?tab=readme-ov-file#-dpg-bench) evaluation guide. + +## VBench + +### 1. Sample prompt videos +```bash +python evaluations/vbench/sample.py \ +--num_frames 49 --height 320 --width 512 \ +--guidance_scale 7 --num_inference_steps 50 --motion_score 9 \ +--ckpt ./checkpoints/URSA-1.7B-FSQ320 \ +--prompt_size 1 --outdir ./samples/vbench/URSA-1.7B-FSQ320 +``` + +### 2. Evaluation +=./samples/vbench/URSA-1.7B-FSQ320 + +Please refer [VBench](https://github.com/Vchitect/VBench?tab=readme-ov-file#evaluation-on-the-standard-prompt-suite-of-vbench) evaluation guide. diff --git a/URSA/docs/training.md b/URSA/docs/training.md new file mode 100644 index 0000000000000000000000000000000000000000..22e2c5427da865daa0c6584d4e76366ab1461b3e --- /dev/null +++ b/URSA/docs/training.md @@ -0,0 +1,104 @@ +# Training Guide +This guide provides simple snippets to train diffnext models. + +# 1. Build VQVAE cache +To optimize training workflow, we preprocess images or videos into VQVAE latents. + +## Requirements: +```bash +pip install protobuf==3.20.3 codewithgpu decord +``` + +## Build T2I cache +Following snippet can be used to cache image latents: + +```python +import os, codewithgpu, torch, PIL.Image, numpy as np +from diffnext.models.autoencoders.autoencoder_vq import AutoencoderVQ + +device, dtype = torch.device("cuda"), torch.float16 +vae = AutoencoderVQ.from_pretrained("/path/to/BAAI/URSA-1.7B-IBQ1024/vae") +vae = vae.to(device=device, dtype=dtype).eval() + +features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"]} +os.makedirs("./datasets/ibq1024_dataset", exist_ok=True) +writer = codewithgpu.RecordWriter("./datasets/ibq1024_dataset", features) + +img = PIL.Image.open("./assets/sample_image.jpg") +x = torch.as_tensor(np.array(img)[None, ...].transpose(0, 3, 1, 2)).to(device).to(dtype) +with torch.no_grad(): + x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.unsqueeze(1).cpu().numpy()[0] +example = {"caption": "long caption", "text": "short text"} +# Ensure enough examples for codewithgou distributed dataset. +[writer.write({"shape": x.shape, "codes": x.tobytes(), **example}) for _ in range(16)] +writer.close() +``` + +## Build T2V cache +Following snippet can be used to cache video latents: + +```python +import os, codewithgpu, torch, decord, numpy as np +from diffnext.models.autoencoders.autoencoder_vq_cosmos3d import AutoencoderVQCosmos3D + +device, dtype = torch.device("cuda"), torch.float16 +vae = AutoencoderVQCosmos3D.from_pretrained("/path/to/URSA-1.7B-FSQ320/vae") +vae = vae.to(device=device, dtype=dtype).eval() + +features = {"codes": "bytes", "caption": "string", "text": "string", "shape": ["int64"], "flow": "float64"} +os.makedirs("./datasets/fsq320_dataset", exist_ok=True) +writer = codewithgpu.RecordWriter("./datasets/fsq320_dataset", features) + +resize, crop_size, frame_ids = 320, (320, 512), list(range(0, 97, 2)) +vid = decord.VideoReader("./assets/sample_video.mp4") +h, w = vid[0].shape[:2] +scale = float(resize) / float(min(h, w)) +size = int(h * scale + 0.5), int(w * scale + 0.5) +y, x = (size[0] - crop_size[0]) // 2, (size[1] - crop_size[1]) // 2 +vid = decord.VideoReader("./assets/sample_video.mp4", height=size[0], width=size[1]) +vid = vid.get_batch(frame_ids).asnumpy() +vid = vid[:, y : y + crop_size[0], x : x + crop_size[1]] +x = torch.as_tensor(vid[None, ...].transpose((0, 4, 1, 2, 3))).to(device).to(dtype) +with torch.no_grad(): + x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.cpu().numpy()[0] +example = {"caption": "long caption", "text": "short text", "flow": 9} +# Ensure enough examples for codewithgou distributed dataset. +[writer.write({"shape": x.shape, "codes": x.tobytes(), **example}) for _ in range(16)] +writer.close() +``` + +# 2. Train models + +## Train T2I model +Following snippet provides simple T2I training arguments: + +```bash +accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \ +--machine_rank 0 --num_machines 1 --num_processes 8 \ +scripts/train.py \ +config="./configs/ursa_1.7b_ibq1024.yaml" \ +experiment.name="ursa_1.7b_ibq1024" \ +experiment.output_dir="./experiments/ursa_1.7b_ibq1024" \ +pipeline.paths.pretrained_path="/path/to/URSA-1.7B-IBQ1024" \ +train_dataloader.params.dataset="./datasets/ibq1024_dataset" \ +model.gradient_checkpointing=3 \ +training.batch_size=4 \ +trainin.gradient_accumulation_steps=16 +``` + +## Train T2V model +Following snippet provides simple T2V training arguments: + +```bash +accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \ +--machine_rank 0 --num_machines 1 --num_processes 8 \ +scripts/train.py \ +config="./configs/ursa_1.7b_fsq320.yaml" \ +experiment.name="ursa_1.7b_fsq320" \ +experiment.output_dir="./experiments/ursa_1.7b_fsq320" \ +pipeline.paths.pretrained_path="/path/to/URSA-1.7B-FSQ320" \ +train_dataloader.params.dataset="./datasets/fsq320_dataset" \ +model.gradient_checkpointing=3 \ +training.batch_size=1 \ +trainin.gradient_accumulation_steps=32 +``` diff --git a/URSA/her/ursa.jpg b/URSA/her/ursa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d548214a59b1515136cc6b5b9d782de0f5622134 Binary files /dev/null and b/URSA/her/ursa.jpg differ diff --git a/URSA/outputs/eval/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..40bed89fbb5a7b74711513afad90c5f3258fc465 Binary files /dev/null and b/URSA/outputs/eval/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1b67dad36dd82c670d7500c53f15ab62348f67c9 Binary files /dev/null and b/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 b/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1b67dad36dd82c670d7500c53f15ab62348f67c9 Binary files /dev/null and b/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 b/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5973bc9ab99563210fe0bfda716ec3b409708b34 Binary files /dev/null and b/URSA/outputs/eval/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..44cb45493285979dc0389436b95a20a89458d10a Binary files /dev/null and b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a805a9c0caa0993dfd355769af431a3a3b9ffdc7 Binary files /dev/null and b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..da966224b985e6600653a707345e4480cd1a86c1 Binary files /dev/null and b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3e40c724669e7fff234cd72619dc4f04c6eb82a8 Binary files /dev/null and b/URSA/outputs/eval/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fa9efb745d60dd3a4201563e8decbe4581cde043 Binary files /dev/null and b/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 b/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fa9efb745d60dd3a4201563e8decbe4581cde043 Binary files /dev/null and b/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 b/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3f84bcdb157e09304149da55e7201e43b0053710 Binary files /dev/null and b/URSA/outputs/eval/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..def87b243cfa644c36eaf1b2211ae0e02892c74a Binary files /dev/null and b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c8ae73bfb2dc3e2acc8119936bf0d3933d8a4d5e Binary files /dev/null and b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..def87b243cfa644c36eaf1b2211ae0e02892c74a Binary files /dev/null and b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e6fffdf8717adfffca12091495dc648dfb82838 Binary files /dev/null and b/URSA/outputs/eval/01_s0_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6a8a7b7e283359c0482b98c3f2b921c0a01c6257 Binary files /dev/null and b/URSA/outputs/eval/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5fe87ba29247f32f2225b6774001ecc1980a474d Binary files /dev/null and b/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0373af6d902eade5297861323c3bc2a04dad2704 Binary files /dev/null and b/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..71a17e50cd81415983026a329418dc0a454143b3 Binary files /dev/null and b/URSA/outputs/eval/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..51ac5e8b29ca65fd79ce8e6cd0c3e2173e38e939 Binary files /dev/null and b/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d1333c4a134cf2fdc7fad54e433978abb8a81a92 Binary files /dev/null and b/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..23257068f980949d52c7778fcd92c3cde6def3f5 Binary files /dev/null and b/URSA/outputs/eval/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 b/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3fa0532fea21f7127847611db475c01ecb6aa43e Binary files /dev/null and b/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..10918097015fbffa9c99bdfff273e234558b22c6 Binary files /dev/null and b/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..acbf74d14ac6dda8157e8601dca967918b5a57f2 Binary files /dev/null and b/URSA/outputs/eval/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 b/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a5e4ab79953a1a25e75ec174ada03d5fed5547d9 Binary files /dev/null and b/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a5e4ab79953a1a25e75ec174ada03d5fed5547d9 Binary files /dev/null and b/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0a7204a5a3611c33a93788d4821d68216896511d Binary files /dev/null and b/URSA/outputs/eval/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..40fbdceb81f48513265cffa10769a72e8ff0e4c4 Binary files /dev/null and b/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1b6f4fee189d4689d7e53c32673699924a15cb10 Binary files /dev/null and b/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 b/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a639bfd04f2ad59046b61af5234463f3c0560dca Binary files /dev/null and b/URSA/outputs/eval/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9f8b9bc479befa9586803eb1d87b4c408c045713 Binary files /dev/null and b/URSA/outputs/eval/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2e22ce12104b0905aadc270219919fb5abe295a3 Binary files /dev/null and b/URSA/outputs/eval/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 b/URSA/outputs/eval/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6ea9bd067bd256aec6e94b4b79cdeefdb02486fd Binary files /dev/null and b/URSA/outputs/eval/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 b/URSA/outputs/eval/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7e5debf0a7fe430a313d1465ef66956c7c6f8576 Binary files /dev/null and b/URSA/outputs/eval/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 b/URSA/outputs/eval/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..35ddf19ff6357e309a11a88dfee96b9969b5ebc0 Binary files /dev/null and b/URSA/outputs/eval/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fc4b655271907c623f8c663fda4d59bcd5d998da Binary files /dev/null and b/URSA/outputs/eval/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4eb94efb83562637891f33dcd7bb74fec27a3752 Binary files /dev/null and b/URSA/outputs/eval/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 b/URSA/outputs/eval/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8bbb48a8098cbb66084a9f30d576de5a699d7122 Binary files /dev/null and b/URSA/outputs/eval/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8bbb48a8098cbb66084a9f30d576de5a699d7122 Binary files /dev/null and b/URSA/outputs/eval/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..13057a968821ab82abafe7671d1e211a3b543357 Binary files /dev/null and b/URSA/outputs/eval/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/04_s0_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/04_s0_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d5ac0780c70d100930d20ab32988a6f9943f4965 Binary files /dev/null and b/URSA/outputs/eval/04_s0_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1f4c97088a7a77fa722bb3d98493099ebe2b9a2a Binary files /dev/null and b/URSA/outputs/eval/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 b/URSA/outputs/eval/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b3ae89192a36f6e5e6aaaa628c2c1984ff10ed3d Binary files /dev/null and b/URSA/outputs/eval/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ced4c6daf952ec4ecab2bc6cd418c3a35faedd81 Binary files /dev/null and b/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 b/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..49576040c542e5718e1a873814e1a77870803bc7 Binary files /dev/null and b/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 b/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ced4c6daf952ec4ecab2bc6cd418c3a35faedd81 Binary files /dev/null and b/URSA/outputs/eval/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 b/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f3c8b0e4b59c70b01a58d3bc4c97755b8820cab4 Binary files /dev/null and b/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 b/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e7eeb72b54c81e6abcf1a1a69441eb7db663712c Binary files /dev/null and b/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cb567a8d60e9a52e7872d56c871debe1de1015e8 Binary files /dev/null and b/URSA/outputs/eval/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..246aedca1c66a0dd497f76c0021d1010f8483875 Binary files /dev/null and b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5231e5c2d2cfb7a797adc3cb2efe3ac2d86823c2 Binary files /dev/null and b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fd4c8ab43c4f1a8a4fe3a94203756e328a074391 Binary files /dev/null and b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8332f59fea416297e1f4ba4a982f4d83820f60cb Binary files /dev/null and b/URSA/outputs/eval/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cf8aa158c28e85894e17dcf3172a00e255d25bbe Binary files /dev/null and b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2c6646ad5fadaa32900879ed351c1e9af4963448 Binary files /dev/null and b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..081fb8a6cf33a1a85dedb19e641b101e6c874f53 Binary files /dev/null and b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e43e239de35b2367a15c5f99bd334b2401f94834 Binary files /dev/null and b/URSA/outputs/eval/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..02c4cc2010ae7877b9b7ea332a9af1483a50f6e8 Binary files /dev/null and b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..638b5e76e52837a9d9553f1d4783d6747fc3b145 Binary files /dev/null and b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..02c4cc2010ae7877b9b7ea332a9af1483a50f6e8 Binary files /dev/null and b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4bf89ea61d6e48039c64e41ccb854e7e8eaf3e62 Binary files /dev/null and b/URSA/outputs/eval/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4b9c044da65f051aaf7b1aeaf4f116336604e7d5 Binary files /dev/null and b/URSA/outputs/eval/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b66bd4e4e571a9390fdd14ea9043e8aad2a51b6c Binary files /dev/null and b/URSA/outputs/eval/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0123e84bcabfdd3f869277ad0e54d50ba592f8b7 Binary files /dev/null and b/URSA/outputs/eval/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b227820907f7e98743d985d0c35d38dc594ea03e Binary files /dev/null and b/URSA/outputs/eval/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..16fad11e24f09c7da5dc83caa79c11e030fdb8ce Binary files /dev/null and b/URSA/outputs/eval/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 b/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3a9a0eb65368834c6313b5e1f680f15e61a19ac6 Binary files /dev/null and b/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3a9a0eb65368834c6313b5e1f680f15e61a19ac6 Binary files /dev/null and b/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..814ee83c1d26341db2ac1468cddca1d9ed2402eb Binary files /dev/null and b/URSA/outputs/eval/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 b/URSA/outputs/eval/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2c411d9d247393d605e8b66f38ee5e6bd34efd47 Binary files /dev/null and b/URSA/outputs/eval/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 b/URSA/outputs/eval/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..db432526d265622751079c2779ac5047c12f0711 Binary files /dev/null and b/URSA/outputs/eval/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9767f1b6ca9ad4a2e200a71d93e5dd74df152b9a Binary files /dev/null and b/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4ede540c402088cffd7f1c391bdd5687328b5f00 Binary files /dev/null and b/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 b/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5f37c7872550956975f455270034e156a900a6ae Binary files /dev/null and b/URSA/outputs/eval/07_s0_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2c746c1080135f5c75a235b96731803bddc6bbf8 Binary files /dev/null and b/URSA/outputs/eval/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ed186f013b7ce13fca1afff5bfeb6d89061b5be6 Binary files /dev/null and b/URSA/outputs/eval/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 b/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..86f142f5518f4eb40032ac5060b64bc566dea007 Binary files /dev/null and b/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4510631cb3164ccbb725b7cb96dfc92eb2c42d0c Binary files /dev/null and b/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..86f142f5518f4eb40032ac5060b64bc566dea007 Binary files /dev/null and b/URSA/outputs/eval/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 b/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6a7b330c63692b7915a521f7dd3b519aa4a59e6c Binary files /dev/null and b/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 b/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e7f37e4b97a337f8ab0dc4d13f54119060212c5d Binary files /dev/null and b/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7f8db14ffdcc7062ed881d9108256939da606a3d Binary files /dev/null and b/URSA/outputs/eval/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0abab762e2ea13078188d854a7f3e605e1cd8e41 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ea781816ca5d4540c48f1d211f882d5b3a5b0ffd Binary files /dev/null and b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0abab762e2ea13078188d854a7f3e605e1cd8e41 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bf654003900d3d4e375fc6853d4d3d399b536876 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s0_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..743bfff3624adcfdb5789d09c0b8dc7f4ec0e4e7 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..04b713db46508ea780800908ca8f59f01ab48d10 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..743bfff3624adcfdb5789d09c0b8dc7f4ec0e4e7 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..00feee8f06cdb5af036739649f100ba6daf971ca Binary files /dev/null and b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..18bad28ae1579ff44aa80c168cfa995ca017b3d6 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s1_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..43b2679dd9c57568ad131baa340bca9f00c085db Binary files /dev/null and b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2fee4ed00bbba2109cafeafb61890ebbe071aab6 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..43b2679dd9c57568ad131baa340bca9f00c085db Binary files /dev/null and b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3e40c724669e7fff234cd72619dc4f04c6eb82a8 Binary files /dev/null and b/URSA/outputs/eval_distill/00_s2_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c68786de7c4a19983faadc7c320a0af1a43646cd Binary files /dev/null and b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f9dcc1ff8026d0cae82afb81ab094485e11d6b8e Binary files /dev/null and b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c68786de7c4a19983faadc7c320a0af1a43646cd Binary files /dev/null and b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..32b30a1ffa685716039550b3258ca81a5803ad1c Binary files /dev/null and b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2ca77b4108163ba2075a05723620adc7baac46dd Binary files /dev/null and b/URSA/outputs/eval_distill/00_s3_a_lone_grizzly_bear_walks_through_a_mist_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d31055e720664385e0ae641989c92dbd4b9bfe07 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..da25d568c415391fcc97f607df95a264121407c5 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d31055e720664385e0ae641989c92dbd4b9bfe07 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e6fffdf8717adfffca12091495dc648dfb82838 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s0_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4a261d42049798c8b73bafcd76691c393ace0154 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..317baf4f9e85e075f4a7c5f1e15d9cde1cc9d540 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4a261d42049798c8b73bafcd76691c393ace0154 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..97a292af395eafdad5fd1a83ac6da0d7f34c0e90 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s1_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e3f71ad99d6c7c871105c288470cdc6f726a59ee Binary files /dev/null and b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8e7b388d81298e710f166069b4dfa203cfaebec6 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e3f71ad99d6c7c871105c288470cdc6f726a59ee Binary files /dev/null and b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..71a17e50cd81415983026a329418dc0a454143b3 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s2_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ec0da5d29b9c1a20e3e1faff44359a798ddd2a7f Binary files /dev/null and b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..dd4a0b7a2436c871e0280392c309e0fe24661cd7 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ec0da5d29b9c1a20e3e1faff44359a798ddd2a7f Binary files /dev/null and b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..23257068f980949d52c7778fcd92c3cde6def3f5 Binary files /dev/null and b/URSA/outputs/eval_distill/01_s3_beautiful_fireworks_in_the_sky_with_red__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..060e74bdd88c7a091f79e954da393fbc35a51706 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8a4693ecc435d6639aa9d65c48582a0008c616ea Binary files /dev/null and b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..060e74bdd88c7a091f79e954da393fbc35a51706 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..acbf74d14ac6dda8157e8601dca967918b5a57f2 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s0_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..82c5b59e32a7946f5e8fd74fadaeb6ab2c7b4eb7 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e98dd861a57572a43170dfd53a930d30fc425af5 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..82c5b59e32a7946f5e8fd74fadaeb6ab2c7b4eb7 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..eadd987cdf4662b2d58a5c396548319c85c5fd9e Binary files /dev/null and b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0a7204a5a3611c33a93788d4821d68216896511d Binary files /dev/null and b/URSA/outputs/eval_distill/02_s1_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6a057f4b46fc04dcdf11fd0621f7a1ee232ba682 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..27c46f567f94548ca9c1861ebf2acc5394c31485 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6a057f4b46fc04dcdf11fd0621f7a1ee232ba682 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3bf7d6e610f433c5f8bd75c602d5b86b38f867f6 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..84127b03dcea417ade2f17b39176ac92a46cf550 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s2_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6bcd399b655438a6091e8eb3d6b9b2be8e439249 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3559e8626bcd6de7b9fc62e7f4a5ad1d903584e9 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6bcd399b655438a6091e8eb3d6b9b2be8e439249 Binary files /dev/null and b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c82fd100a2e3bf14fda3eaa2242e5347b931ec8f Binary files /dev/null and b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3583c6cecd961ea0182a5feebe5b6eacd47a0a2b Binary files /dev/null and b/URSA/outputs/eval_distill/02_s3_a_wave_crashes_on_a_rocky_shoreline_at_s_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fe9a76e428549bf10a44905028ef21fd2714de0e Binary files /dev/null and b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a0cd2e805231fc1dcd48a7f96a79b2d2e6a3070 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fe9a76e428549bf10a44905028ef21fd2714de0e Binary files /dev/null and b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..308c920468deb7fec32708478e88f660238dd4db Binary files /dev/null and b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8959c1b1a97752704310a56b08068d8bf839dd2f Binary files /dev/null and b/URSA/outputs/eval_distill/03_s0_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fcb242c51414eef0d40c130994ea2c52f83ef8bd Binary files /dev/null and b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a8628dd06b1495f549368595adc371315d8e202 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fcb242c51414eef0d40c130994ea2c52f83ef8bd Binary files /dev/null and b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..97084e1e6de97a168ffc231baf4a336e2824e1dd Binary files /dev/null and b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2deda2d17c78633550fc5f5a6504755c1c17191c Binary files /dev/null and b/URSA/outputs/eval_distill/03_s1_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 b/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5b4f7aae7e97b51bb6c8bbecc0d571cc74841599 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ffe0d960bcd2ff3f20163ad66e0d918ae2338426 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5b4f7aae7e97b51bb6c8bbecc0d571cc74841599 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s2_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e11ddd93239bf28e9da2c60e6d9fed21b6144949 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..19aebc7196a5793672ed1bb594813f5078b7c022 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e11ddd93239bf28e9da2c60e6d9fed21b6144949 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e2e03b17a53fa1022d59c3e57ffe1f9a6471db1b Binary files /dev/null and b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e4a05f0c411fdf6956683a2497e1adadccf76431 Binary files /dev/null and b/URSA/outputs/eval_distill/03_s3_a_hummingbird_hovers_in_front_of_a_red_f_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d25ba443f970ad58e07cd6600487e77e61242b1b Binary files /dev/null and b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..40d4620e65fec163b421f428f811ff115ba9e39a Binary files /dev/null and b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d25ba443f970ad58e07cd6600487e77e61242b1b Binary files /dev/null and b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d5ac0780c70d100930d20ab32988a6f9943f4965 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s0_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1a325815e8334dedd8f1ba624c64e03ad1424ad1 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..90f316e2ae8553ab99bf34ff5dcbf5b4f46bbb91 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1a325815e8334dedd8f1ba624c64e03ad1424ad1 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1fa1011db29f943612773ffaeb6a7e1702c4c822 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..dc453d9e5d028dc7fe96489921e422722b8d8a1c Binary files /dev/null and b/URSA/outputs/eval_distill/04_s1_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c699c9da7debc6e8737217e4981da503dccc2e47 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ace4fba2d61a7fd1c8f08a05b5366595fb66ca0a Binary files /dev/null and b/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c699c9da7debc6e8737217e4981da503dccc2e47 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s2_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..edb794f8d3c11d34eab50d19012e75040d824aaf Binary files /dev/null and b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..67992a58fd249add3b32bd101571e8f916721e33 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..edb794f8d3c11d34eab50d19012e75040d824aaf Binary files /dev/null and b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c6608ab594a9e42c258d7e2d2af3e6fc37ea171c Binary files /dev/null and b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cb567a8d60e9a52e7872d56c871debe1de1015e8 Binary files /dev/null and b/URSA/outputs/eval_distill/04_s3_timelapse_of_clouds_rolling_over_mountai_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..17bb9959acac5fb948e60367d674038b106a09b4 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..787cb38c661a4989721b68d42642202076561cc8 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..17bb9959acac5fb948e60367d674038b106a09b4 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c1b97ec6d4c78b5bb7807f7c57524667d81b5c85 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8332f59fea416297e1f4ba4a982f4d83820f60cb Binary files /dev/null and b/URSA/outputs/eval_distill/05_s0_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2fdf2d31468fd2b5d0ffb64c87b5cc4248e1fa62 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9dfa1effe64300c76ea3a7f291da23b28fca4aef Binary files /dev/null and b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2fdf2d31468fd2b5d0ffb64c87b5cc4248e1fa62 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7a95ff6eb08931ca052e54e9cefe582ba869fbcf Binary files /dev/null and b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e43e239de35b2367a15c5f99bd334b2401f94834 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s1_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e3ad271883769aba2fdb7cbcc7ecc7cc15ffa16e Binary files /dev/null and b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..11329eb0ae8280fbc19b5a8a7e64a727b97f5e76 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..e3ad271883769aba2fdb7cbcc7ecc7cc15ffa16e Binary files /dev/null and b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7a37fd8b2b45444683aebdd7cd8a709c3a8138c3 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6f1e1224e0f7898fdf41024f44cda63f6e8ca916 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s2_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cca73702bd3c7e0e1743711fba9500eb5cf4f2c9 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..009fe6ae97fe5e102048785853470941a9160b6d Binary files /dev/null and b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..cca73702bd3c7e0e1743711fba9500eb5cf4f2c9 Binary files /dev/null and b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8863578bf2b7c369eac5a88aa5b56de79bf892eb Binary files /dev/null and b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b66bd4e4e571a9390fdd14ea9043e8aad2a51b6c Binary files /dev/null and b/URSA/outputs/eval_distill/05_s3_a_neonlit_city_street_at_night_with_rain_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..87df35b0bd01c15eda6c4817b75c340186f919d9 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..89c76e44a72a53880cf353d43576696955d635be Binary files /dev/null and b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..87df35b0bd01c15eda6c4817b75c340186f919d9 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..309d90215f9fa75ed1a78ee5f18a69ebfe5d4065 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8b3b5944d7650a513d88d3b2230ecb08b3c64af1 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s0_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a2a8b4e4df902650a1aab1a603f3de96ee81688 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b0c3d0ada615f3672b57cdb209fdec88424b10b6 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a2a8b4e4df902650a1aab1a603f3de96ee81688 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f23a6523e398a87b01bcbf73aaff0525f9e16756 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..16fad11e24f09c7da5dc83caa79c11e030fdb8ce Binary files /dev/null and b/URSA/outputs/eval_distill/06_s1_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8e9845a0d693ca3db025d94afaf2a122f98bb9e5 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f25b42b83cd4cc843042298b53c9693af47494fc Binary files /dev/null and b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8e9845a0d693ca3db025d94afaf2a122f98bb9e5 Binary files /dev/null and b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..814ee83c1d26341db2ac1468cddca1d9ed2402eb Binary files /dev/null and b/URSA/outputs/eval_distill/06_s2_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..caf9f264d017deabdfc45539a382a29def9a626a Binary files /dev/null and b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..b1a710920c3207bc4f698ea2236953331a58695d Binary files /dev/null and b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..caf9f264d017deabdfc45539a382a29def9a626a Binary files /dev/null and b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ecc5dbf656c4137e6c654f7e4117c8c6813c5a4d Binary files /dev/null and b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..410103f9e37e217b5c12107c1162fa6e81f2679e Binary files /dev/null and b/URSA/outputs/eval_distill/06_s3_a_kitten_playing_with_a_ball_of_yarn_on__teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f6edfd30291c1ef4305ee26ae7116116eb18c718 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..aec035468fca0d9d2e3a238a5d30620cb1191f0d Binary files /dev/null and b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f6edfd30291c1ef4305ee26ae7116116eb18c718 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0bd20339a4a3249dc0a550ad60b4decc89829480 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s0_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5f85c66b8b3e39af02a086ce977c911be6ca0d74 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..83ce88cd214460f1bb26180793db404909236727 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5f85c66b8b3e39af02a086ce977c911be6ca0d74 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..396a813b9a8d3e6dc999bd124cf2cab65823d448 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3674770c65fbb8cd8cacc78155ad1019bf8bfecd Binary files /dev/null and b/URSA/outputs/eval_distill/07_s1_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 b/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..69a0f8e17d46a2c9b4b3b695002a5817da69d647 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..72ebcdddc44d21c7a5ad12751d4135c9c9cf5f37 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..69a0f8e17d46a2c9b4b3b695002a5817da69d647 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s2_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a1a014e33d868be1bbaf9c3fcd075b2418dc1244 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_baked.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e62fe9d021b534ac23dc5c07ddf77db41f69736 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a1a014e33d868be1bbaf9c3fcd075b2418dc1244 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_student_1step_no_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8d12243650bad3d8236798bc84b41586896a8578 Binary files /dev/null and b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_cfg.mp4 differ diff --git a/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7f8db14ffdcc7062ed881d9108256939da606a3d Binary files /dev/null and b/URSA/outputs/eval_distill/07_s3_astronaut_floating_weightlessly_inside_a_teacher_25step_no_cfg.mp4 differ diff --git a/URSA/scripts/ab_test_inference.py b/URSA/scripts/ab_test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c392e118c8525ad00fcdee0fc0985ddfc9e8aa --- /dev/null +++ b/URSA/scripts/ab_test_inference.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +"""A/B test: official URSA inference vs eval_distill_dimo inference. + +This script runs the EXACT same pipeline call in two ways: + A) "official" — follows README Quick Start verbatim + B) "eval" — follows eval_distill_dimo.py logic + +Both use the same pipeline instance, same prompt, same seed. +Saves side-by-side outputs + prints every intermediate diagnostic. + +Usage: + python scripts/ab_test_inference.py \ + --model /gfs/space/private/fengzl/World_Model/URSA-1.7B \ + --device 0 + +This will generate: + outputs/ab_test/official_t2i.jpg + outputs/ab_test/official_t2v.mp4 + outputs/ab_test/eval_teacher_cfg.mp4 + outputs/ab_test/eval_teacher_nocfg.mp4 + outputs/ab_test/eval_student_*.mp4 (if --student_ckpt given) +""" + +import argparse +import os +import sys + +import numpy as np +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from diffnext.pipelines import URSAPipeline +from diffnext.utils import export_to_image, export_to_video + + +def parse_args(): + p = argparse.ArgumentParser(description="A/B test URSA inference") + p.add_argument("--model", required=True, help="URSA model path") + p.add_argument("--student_ckpt", default=None, help="Optional student.pt") + p.add_argument("--device", type=int, default=0) + p.add_argument("--precision", default="float16", choices=["float16", "bfloat16"]) + p.add_argument("--out_dir", default="./outputs/ab_test") + # Test different resolutions — FSQ320 native is 320x512 + p.add_argument("--test_resolutions", nargs="+", default=["320x512"], + help="Resolutions to test as HxW strings (FSQ320 native: 320x512)") + p.add_argument("--test_steps", nargs="+", type=int, default=[25, 50], + help="Inference steps to test") + p.add_argument("--num_frames", type=int, default=49) + return p.parse_args() + + +def diag(label, obj): + """Print diagnostic.""" + print(f" [{label}] {obj}") + + +def diag_tensor(label, t): + """Print tensor diagnostics.""" + if isinstance(t, torch.Tensor): + print(f" [{label}] shape={t.shape} dtype={t.dtype} device={t.device} " + f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.mean().item():.4f}") + elif isinstance(t, np.ndarray): + print(f" [{label}] shape={t.shape} dtype={t.dtype} " + f"min={t.min()} max={t.max()} mean={t.mean():.2f}") + + +def diag_pipeline(pipe): + """Full pipeline diagnostic.""" + print("\n" + "=" * 70) + print(" PIPELINE DIAGNOSTICS") + print("=" * 70) + print(f" pipeline class : {type(pipe).__name__}") + print(f" transformer class : {type(pipe.transformer).__name__}") + print(f" transformer device : {next(pipe.transformer.parameters()).device}") + print(f" transformer dtype : {next(pipe.transformer.parameters()).dtype}") + print(f" vae class : {type(pipe.vae).__name__}") + print(f" vae device : {next(pipe.vae.parameters()).device}") + print(f" scheduler class : {type(pipe.scheduler).__name__}") + print(f" scheduler repr : {repr(pipe.scheduler)}") + + sched = pipe.scheduler + if hasattr(sched, 'path') and sched.path is not None: + print(f" scheduler.path class: {type(sched.path).__name__}") + if hasattr(sched.path, 'emb'): + emb = sched.path.emb + print(f" path.emb shape : {emb.shape}") + print(f" path.emb device : {emb.device}") + print(f" path.emb dtype : {emb.dtype}") + print(f" path.emb[0,:5] : {emb[0,:5].tolist()}") + if hasattr(sched.path, 'alpha'): + print(f" path.alpha : {getattr(sched.path, 'alpha', 'N/A')}") + if hasattr(sched.path, 'c'): + print(f" path.c : {getattr(sched.path, 'c', 'N/A')}") + else: + print(f" scheduler.path : MISSING or None!") + + print(f" codebook_size : {getattr(sched, 'codebook_size', 'N/A')}") + print(f" shift : {getattr(sched, 'shift', 'N/A')}") + + if hasattr(sched, 'config'): + print(f" scheduler.config : {dict(sched.config)}") + + print(f" vae_temporal_stride : {getattr(pipe, 'vae_temporal_stride', 'N/A')}") + print(f" vae_spatial_stride : {getattr(pipe, 'vae_spatial_stride', 'N/A')}") + print(f" tokenizer class : {type(pipe.tokenizer).__name__}") + print("=" * 70 + "\n") + + +def diag_output(frames_output, label): + """Diagnose pipeline output.""" + print(f"\n --- Output diagnostics: {label} ---") + if isinstance(frames_output, np.ndarray): + diag_tensor(f"{label} raw", frames_output) + elif isinstance(frames_output, list): + print(f" [{label}] list of {len(frames_output)} items") + if len(frames_output) > 0: + f0 = frames_output[0] + if isinstance(f0, np.ndarray): + diag_tensor(f"{label}[0]", f0) + else: + print(f" [{label}[0]] type={type(f0)}") + else: + print(f" [{label}] type={type(frames_output)}") + + +def save_frames(frames, path, fps=12): + """Save frames as video or image.""" + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + if path.endswith(".mp4"): + if isinstance(frames, np.ndarray) and frames.ndim == 4: + export_to_video(list(frames), output_video_path=path, fps=fps) + elif isinstance(frames, list): + export_to_video(frames, output_video_path=path, fps=fps) + else: + export_to_video(frames, output_video_path=path, fps=fps) + elif path.endswith((".jpg", ".png")): + from PIL import Image + if isinstance(frames, np.ndarray): + Image.fromarray(frames).save(path) + elif hasattr(frames, 'save'): + frames.save(path) + + +def main(): + args = parse_args() + os.makedirs(args.out_dir, exist_ok=True) + + dtype = getattr(torch, args.precision) + device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu") + + prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur." + negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly" + seed = 42 + + # ===================================================================== + # Load pipeline + # ===================================================================== + print(f"\n[1] Loading pipeline from {args.model} ...") + pipe = URSAPipeline.from_pretrained( + args.model, torch_dtype=dtype, trust_remote_code=True + ).to(device) + + diag_pipeline(pipe) + + # ===================================================================== + # Test A: Official README T2V (exact copy from README for FSQ320) + # FSQ320: height=320, width=512, num_frames=49, steps=50 + # ===================================================================== + print("\n" + "#" * 70) + print("# TEST A: Official README T2V (FSQ320 native resolution)") + print("#" * 70) + + gen = torch.Generator(device=device).manual_seed(seed) + out = pipe( + prompt=f"motion=9.0, {prompt}", + negative_prompt=negative_prompt, + height=320, + width=512, + num_frames=49, + num_inference_steps=50, + guidance_scale=7, + generator=gen, + output_type="np", + ) + frames = out.frames + diag_output(frames, "A_official_t2v") + if isinstance(frames, np.ndarray): + video_frames = frames[0] if frames.ndim == 5 else frames + else: + video_frames = frames + path_a = os.path.join(args.out_dir, "A_official_t2v_320x512_49f_50step.mp4") + try: + if isinstance(video_frames, np.ndarray): + export_to_video(list(video_frames), output_video_path=path_a, fps=12) + else: + export_to_video(video_frames, output_video_path=path_a, fps=12) + print(f" Saved: {path_a}") + except Exception as e: + print(f" Failed: {e}") + + # Also test T2I at native resolution (1 frame) + print("\n# TEST A2: T2I at 320x512 (1 frame)") + gen = torch.Generator(device=device).manual_seed(seed) + out = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=320, + width=512, + num_frames=1, + num_inference_steps=25, + guidance_scale=7, + generator=gen, + ) + image = out.frames[0] + path_a2 = os.path.join(args.out_dir, "A_official_t2i_320x512.jpg") + if hasattr(image, 'save'): + image.save(path_a2) + print(f" Saved: {path_a2} (PIL Image)") + else: + diag_output(out.frames, "A2_t2i") + + # ===================================================================== + # Test B: Different resolutions and step counts for video + # ===================================================================== + for res_str in args.test_resolutions: + h, w = map(int, res_str.split("x")) + for steps in args.test_steps: + for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]: + label = f"B_{h}x{w}_{steps}step_{gs_label}" + print(f"\n{'#' * 70}") + print(f"# TEST {label}") + print(f"# height={h} width={w} num_frames={args.num_frames}") + print(f"# steps={steps} guidance_scale={gs_val}") + print(f"{'#' * 70}") + + gen = torch.Generator(device=device).manual_seed(seed) + neg = negative_prompt if gs_val > 1 else None + + # Print scheduler state before call + print(f" scheduler.codebook_size = {pipe.scheduler.codebook_size}") + print(f" scheduler.path type = {type(pipe.scheduler.path).__name__}") + + out = pipe( + prompt=prompt, + negative_prompt=neg, + height=h, + width=w, + num_frames=args.num_frames, + num_inference_steps=steps, + guidance_scale=gs_val, + guidance_trunc=0.9, + max_prompt_length=320, + vae_batch_size=1, + output_type="np", + generator=gen, + ) + + frames = out.frames + diag_output(frames, label) + + # For video output (num_frames > 1), frames is [batch, T, H, W, 3] + if isinstance(frames, np.ndarray): + if frames.ndim == 5: + video_frames = frames[0] # [T, H, W, 3] + elif frames.ndim == 4: + video_frames = frames # [T, H, W, 3] + else: + video_frames = frames + elif isinstance(frames, list): + video_frames = frames + else: + video_frames = frames + + path = os.path.join(args.out_dir, f"{label}.mp4") + try: + if isinstance(video_frames, np.ndarray): + export_to_video(list(video_frames), output_video_path=path, fps=12) + else: + export_to_video(video_frames, output_video_path=path, fps=12) + print(f" Saved: {path}") + except Exception as e: + print(f" Failed to save {path}: {e}") + + # ===================================================================== + # Test C: Student (if provided) + # ===================================================================== + if args.student_ckpt: + print(f"\n{'#' * 70}") + print(f"# TEST C: Student 1-step") + print(f"{'#' * 70}") + + teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()} + student_state = torch.load(args.student_ckpt, map_location=device, weights_only=True) + + print(f" student keys: {len(student_state)}") + print(f" teacher keys: {len(teacher_state)}") + + # Check key compatibility + missing = set(teacher_state.keys()) - set(student_state.keys()) + extra = set(student_state.keys()) - set(teacher_state.keys()) + if missing: + print(f" WARNING: {len(missing)} keys in teacher but not student: {list(missing)[:5]}") + if extra: + print(f" WARNING: {len(extra)} keys in student but not teacher: {list(extra)[:5]}") + + pipe.transformer.load_state_dict(student_state, strict=True) + pipe.transformer.eval() + + for res_str in args.test_resolutions[:1]: # Just first resolution + h, w = map(int, res_str.split("x")) + for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]: + label = f"C_student_{h}x{w}_1step_{gs_label}" + gen = torch.Generator(device=device).manual_seed(seed) + neg = negative_prompt if gs_val > 1 else None + + out = pipe( + prompt=prompt, + negative_prompt=neg, + height=h, + width=w, + num_frames=args.num_frames, + num_inference_steps=1, + guidance_scale=gs_val, + guidance_trunc=0.9, + max_prompt_length=320, + vae_batch_size=1, + output_type="np", + generator=gen, + ) + + frames = out.frames + diag_output(frames, label) + + if isinstance(frames, np.ndarray): + video_frames = frames[0] if frames.ndim == 5 else frames + else: + video_frames = frames + + path = os.path.join(args.out_dir, f"{label}.mp4") + try: + if isinstance(video_frames, np.ndarray): + export_to_video(list(video_frames), output_video_path=path, fps=12) + else: + export_to_video(video_frames, output_video_path=path, fps=12) + print(f" Saved: {path}") + except Exception as e: + print(f" Failed to save {path}: {e}") + + # Restore teacher + pipe.transformer.load_state_dict(teacher_state, strict=True) + + print(f"\n[DONE] All outputs in {args.out_dir}") + print("\nCheck these files to diagnose blurriness:") + print(" - A_official_t2i_1024x1024.jpg → should be sharp (official T2I)") + print(" - B_*_cfg7.mp4 → teacher video with CFG") + print(" - B_*_nocfg.mp4 → teacher video without CFG") + print(" - Compare different resolutions and step counts") + print(" - If ALL are blurry, the issue is in pipeline/scheduler/VAE loading") + print(" - If only low-res are blurry, it's a resolution issue") + print(" - If only low-step are blurry, need more steps") + + +if __name__ == "__main__": + main() diff --git a/URSA/scripts/app_ursa_t2i.py b/URSA/scripts/app_ursa_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..03b67657c823a3765dd61d52f248d964f8a63c5f --- /dev/null +++ b/URSA/scripts/app_ursa_t2i.py @@ -0,0 +1,147 @@ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------ +"""URSA T2I application.""" + +import argparse +import os + +import gradio as gr +import numpy as np +import torch + +from diffnext.pipelines import URSAPipeline +from diffnext.utils import export_to_image + +# Switch to the allocator optimized for dynamic shape. +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + +def parse_args(): + """Parse arguments.""" + parser = argparse.ArgumentParser(description="Serve URSA T2I application") + parser.add_argument("--model", default="", help="model path") + parser.add_argument("--device", type=int, default=0, help="device index") + parser.add_argument("--precision", default="float16", help="compute precision") + return parser.parse_args() + + +def generate_image( + prompt, + negative_prompt, + seed, + randomize_seed, + width, + height, + guidance_scale, + num_inference_steps, +): + """Generate an image.""" + args = locals() + seed = np.random.randint(2147483647) if randomize_seed else seed + device = getattr(pipe, "_offload_device", pipe.device) + generator = torch.Generator(device=device).manual_seed(seed) + images = pipe(generator=generator, **args).frames + return [export_to_image(image, quality=95) for image in images] + [seed] + + +css = """#col-container {margin: 0 auto; max-width: 1366px}""" +title = "Uniform Discrete Diffusion with Metric Path for Video Generation" +header = ( + "
" + "

Uniform Discrete Diffusion with Metric Path for Video Generation

" + "

[paper]" + "[code]

" + "
" +) + +examples = [ + "a selfie of an old man with a white beard.", + "a woman with long hair next to a luminescent bird.", + "a digital artwork of a cat styled in a whimsical fashion. The overall vibe is quirky and artistic.", # noqa + "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.", + "a beautiful afghan women by red hair and green eyes.", + "beautiful fireworks in the sky with red, white and blue.", + "A dragon perched majestically on a craggy, smoke-wreathed mountain.", + "A photo of llama wearing sunglasses standing on the deck of a spaceship with the Earth in the background.", # noqa + "Two pandas in fluffy slippers and bathrobes, lazily munching on bamboo.", +] + + +if __name__ == "__main__": + args = parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu", args.device) + model_args = {"torch_dtype": getattr(torch, args.precision.lower()), "trust_remote_code": True} + pipe = URSAPipeline.from_pretrained(args.model, **model_args).to(device) + + # Main Application. + app = gr.Blocks(css=css, theme="origin").__enter__() + container = gr.Column(elem_id="col-container").__enter__() + _, main_row = gr.Markdown(header), gr.Row().__enter__() + + # Input. + input_col = gr.Column().__enter__() + prompt = gr.Text( + label="Prompt", + placeholder="Describe the video you want to generate", + value="A lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.", # noqa + lines=5, + ) + negative_prompt = gr.Text( + label="Negative Prompt", + placeholder="Describe what you don't want in the image", + value="worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly", # noqa + lines=5, + ) + # fmt: off + options = gr.Accordion("Options", open=False).__enter__() + seed = gr.Slider(label="Seed", maximum=2147483647, step=1, value=0) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=10, step=0.1, value=7) + with gr.Row(): + width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=1024) + height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=1024) + num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=25) # noqa + options.__exit__() + generate_btn = gr.Button("Generate Image", variant="primary", size="lg") + input_col.__exit__() + # fmt: on + + # Results. + result = gr.Image(label="Result", height=720, show_label=False) + main_row.__exit__() + + # Examples. + with gr.Row(): + gr.Examples(examples=examples, inputs=[prompt]) + + # Events. + container.__exit__() + gr.on( + triggers=[generate_btn.click, prompt.submit, negative_prompt.submit], + fn=generate_image, + inputs=[ + prompt, + negative_prompt, + seed, + randomize_seed, + width, + height, + guidance_scale, + num_inference_steps, + ], + outputs=[result, seed], + ) + app.__exit__(), app.launch(share=False) diff --git a/URSA/scripts/app_ursa_ti2v.py b/URSA/scripts/app_ursa_ti2v.py new file mode 100644 index 0000000000000000000000000000000000000000..9b22c26ba76890795934454b6466352db31c82c1 --- /dev/null +++ b/URSA/scripts/app_ursa_ti2v.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------ +"""URSA TI2V application.""" + +import argparse +import os + +import gradio as gr +import numpy as np +import PIL.Image +import torch + +from diffnext.pipelines import URSAPipeline +from diffnext.utils import export_to_image, export_to_video + +# Fix tokenizer fork issue. +os.environ["TOKENIZERS_PARALLELISM"] = "true" +# Switch to the allocator optimized for dynamic shape. +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + +def parse_args(): + """Parse arguments.""" + parser = argparse.ArgumentParser(description="Serve URSA TI2V application") + parser.add_argument("--model", default="", help="model path") + parser.add_argument("--device", type=int, default=0, help="device index") + parser.add_argument("--precision", default="float16", help="compute precision") + return parser.parse_args() + + +def crop_image(image, target_h, target_w): + """Center crop image to target size.""" + h, w = image.height, image.width + aspect_ratio_target, aspect_ratio = target_w / target_h, w / h + if aspect_ratio > aspect_ratio_target: + new_w = int(h * aspect_ratio_target) + x_start = (w - new_w) // 2 + image = image.crop((x_start, 0, x_start + new_w, h)) + else: + new_h = int(w / aspect_ratio_target) + y_start = (h - new_h) // 2 + image = image.crop((0, y_start, w, y_start + new_h)) + return np.array(image.resize((target_w, target_h), PIL.Image.Resampling.BILINEAR)) + + +def generate_image( + prompt, + negative_prompt, + seed, + randomize_seed, + guidance_scale, + num_inference_steps=25, +): + """Generate a video.""" + args = {**locals(), **video_presets["t2i"]} + seed = np.random.randint(2147483647) if randomize_seed else seed + device = getattr(pipe, "_offload_device", pipe.device) + generator = torch.Generator(device=device).manual_seed(seed) + images = pipe(generator=generator, **args).frames + return [export_to_image(image, quality=95) for image in images] + [seed] + + +def generate_video( + prompt, + negative_prompt, + image, + motion_score, + seed, + randomize_seed, + guidance_scale, + num_inference_steps, + output_type="np", +): + """Generate a video.""" + args = {**locals(), **video_presets["ti2v"]} + args["prompt"] = f"motion={motion_score:.1f}, {prompt}" + args["image"] = crop_image(image, args["height"], args["width"]) if image else None + seed = np.random.randint(2147483647) if randomize_seed else seed + device = getattr(pipe, "_offload_device", pipe.device) + generator = torch.Generator(device=device).manual_seed(seed) + frames = pipe(generator=generator, **args).frames[0] + return export_to_video(frames, fps=12), seed + + +css = """#col-container {margin: 0 auto; max-width: 1366px}""" +title = "Uniform Discrete Diffusion with Metric Path for Video Generation" +header = ( + "
" + "

Uniform Discrete Diffusion with Metric Path for Video Generation

" + "

[paper]" + "[code]

" + "
" +) + +video_presets = { + "t2i": {"width": 512, "height": 320, "num_frames": 1}, + "ti2v": {"width": 512, "height": 320, "num_frames": 49}, +} + +prompts = [ + "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.", + "Many spotted jellyfish pulsating under water. Their bodies are transparent and glowing in deep ocean.", # noqa + "An intense close-up of a soldier’s face, covered in dirt and sweat, his eyes filled with determination as he surveys the battlefield.", # noqa + "a close-up shot of a woman standing in a dimly lit room. she is wearing a traditional chinese outfit, which includes a red and gold dress with intricate designs and a matching headpiece. the woman has her hair styled in an updo, adorned with a gold accessory. her makeup is done in a way that accentuates her features, with red lipstick and dark eyeshadow. she is looking directly at the camera with a neutral expression. the room has a rustic feel, with wooden beams and a stone wall visible in the background. the lighting in the room is soft and warm, creating a contrast with the woman's vibrant attire. there are no texts or other objects in the video. the style of the video is a portrait, focusing on the woman and her attire.", # noqa + "The camera slowly rotates around a massive stack of vintage televisions that are placed within a large New York museum gallery. Each of the televisions is showing a different program. There are 1950s sci-fi movies with their distinctive visuals, horror movies with their creepy scenes, news broadcasts with moving images and words, static on some screens, and a 1970s sitcom with its characteristic look. The televisions are of various sizes and designs, some with rounded edges and others with more angular shapes. The gallery is well-lit, with light falling on the stack of televisions and highlighting the different programs being shown. There are no people visible in the immediate vicinity, only the stack of televisions and the surrounding gallery space.", # noqa +] +motion_scores = [9, 9, 9, 9, 9] +videos = ["", "", "", "", ""] +examples = [list(x) for x in zip(prompts, motion_scores)] + + +if __name__ == "__main__": + args = parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu", args.device) + model_args = {"torch_dtype": getattr(torch, args.precision.lower()), "trust_remote_code": True} + pipe = URSAPipeline.from_pretrained(args.model, **model_args).to(device) + + # Application. + app = gr.Blocks(css=css, theme="origin").__enter__() + container = gr.Column(elem_id="col-container").__enter__() + _, main_row = gr.Markdown(header), gr.Row().__enter__() + + # Input. + input_col = gr.Column().__enter__() + prompt = gr.Text( + label="Prompt", + placeholder="Describe the video you want to generate", + value="A lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.", # noqa + lines=5, + ) + negative_prompt = gr.Text( + label="Negative Prompt", + placeholder="Describe what you don't want in the video", + value="worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly", # noqa + lines=1, + ) + with gr.Row(): + generate_image_btn = gr.Button("Generate Image Prompt", variant="primary", size="lg") + generate_video_btn = gr.Button("Generate Video", variant="primary", size="lg") + image_prompt = gr.Image(label="Image Prompt", height=480, type="pil") + + # fmt: off + options = gr.Accordion("Options", open=False).__enter__() + seed = gr.Slider(label="Seed", maximum=2147483647, step=1, value=0) + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=10.0, step=0.1, value=7.0) + with gr.Row(): + num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=50) # noqa + options.__exit__(), input_col.__exit__() + + # Results. + result_col = gr.Column().__enter__() + motion = gr.Slider(label="Motion Score", minimum=1, maximum=10, step=1, value=9) + result = gr.Video(label="Result", height=480, show_label=False, autoplay=True) + result_col.__exit__(), main_row.__exit__() + # fmt: on + + # Examples. + with gr.Row(): + gr.Examples(examples=examples, inputs=[prompt, motion]) + + # Events. + container.__exit__() + gr.on( + triggers=[generate_image_btn.click, prompt.submit, negative_prompt.submit], + fn=generate_image, + inputs=[ + prompt, + negative_prompt, + seed, + randomize_seed, + guidance_scale, + ], + outputs=[image_prompt, seed], + ) + gr.on( + triggers=[generate_video_btn.click, prompt.submit, negative_prompt.submit], + fn=generate_video, + inputs=[ + prompt, + negative_prompt, + image_prompt, + motion, + seed, + randomize_seed, + guidance_scale, + num_inference_steps, + ], + outputs=[result, seed], + ) + app.__exit__(), app.launch(share=False) diff --git a/URSA/scripts/eval_distill_dimo.py b/URSA/scripts/eval_distill_dimo.py new file mode 100644 index 0000000000000000000000000000000000000000..62906de50450f4bdfb32097fe139597adf3cfb03 --- /dev/null +++ b/URSA/scripts/eval_distill_dimo.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------- +"""Evaluation script for distill_dimo checkpoints. + +Generates videos from both the student (1-step) and teacher (multi-step) +using checkpoints saved by train_distill_dimo.py. + +Verified native inference regime (from A/B testing — ground truth): + height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50. + no_cfg (guidance_scale=1) does NOT produce valid output for this URSA + checkpoint — outputs are blank or blurry. + +Student generation modes +------------------------ + cfg : 1-step, guidance_scale=7 (2× forward, inference-time CFG) + +Teacher generation modes +------------------------ + cfg : 50-step, guidance_scale=7 (official working regime) + +Usage: + python scripts/eval_distill_dimo.py \ + --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B \ + --student_ckpt ./experiments/distill_dimo_v3/checkpoints/checkpoint-200/student.pt \ + --out_dir ./outputs/eval_distill_v3_200steps_49frames +""" + +import argparse +import os +import sys + +import numpy as np +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from diffnext.pipelines import URSAPipeline +from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import ( + VERIFIED_NATIVE_DEFAULTS, + check_verified_regime, +) +from diffnext.utils import export_to_video + + +# --------------------------------------------------------------------------- +# Default prompts and seeds +# --------------------------------------------------------------------------- + +DEFAULT_PROMPTS = [ + "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.", + "beautiful fireworks in the sky with red, white and blue.", + "a wave crashes on a rocky shoreline at sunset, slow motion.", + "a hummingbird hovers in front of a red flower, wings a blur.", + "timelapse of clouds rolling over mountain peaks.", + "a neon-lit city street at night with rain-soaked reflections.", + "a kitten playing with a ball of yarn on a wooden floor.", + "astronaut floating weightlessly inside a space station.", +] + +# Official URSA negative prompt (from README / app scripts) +DEFAULT_NEGATIVE_PROMPT = ( + "worst quality, low quality, inconsistent motion, static, still, " + "blurry, jittery, distorted, ugly" +) + +DEFAULT_SEEDS = [0, 1, 2, 3] + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser( + description="Evaluate distill_dimo student (1-step) vs teacher (multi-step)" + ) + + p.add_argument("--teacher_ckpt", required=True, + help="URSA diffusers pipeline directory (teacher weights)") + p.add_argument("--student_ckpt", required=True, + help="student.pt from train_distill_dimo.py checkpoint") + p.add_argument("--out_dir", default="./outputs/eval_distill") + + # Geometry — verified native: 320×512×49 (from A/B testing) + p.add_argument("--num_frames", type=int, default=49) + p.add_argument("--height", type=int, default=320) + p.add_argument("--width", type=int, default=512) + p.add_argument("--fps", type=int, default=12) + + # Student generation — default: cfg only (no_cfg is known to fail) + p.add_argument("--student_modes", nargs="+", default=["cfg"], + choices=["no_cfg", "cfg", "baked"], + help="Student generation modes to evaluate. " + "Default: ['cfg']. no_cfg is known to produce blank/blurry " + "output for this checkpoint.") + p.add_argument("--eval_cfg_scale", type=float, default=7.0, + help="Guidance scale for 'cfg' mode (verified working value=7)") + + # Teacher generation — default: cfg only (no_cfg is known to fail) + p.add_argument("--teacher_modes", nargs="+", default=["cfg"], + choices=["no_cfg", "cfg"], + help="Teacher generation modes. Default: ['cfg']. " + "no_cfg is NOT a valid baseline for this URSA checkpoint.") + p.add_argument("--teacher_steps", type=int, default=50, + help="Number of inference steps for teacher (verified default=50)") + + # Shared generation params (match verified official defaults) + p.add_argument("--guidance_trunc", type=float, default=0.9, + help="Truncation threshold for inference CFG") + p.add_argument("--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, + help="Negative prompt for CFG (official URSA uses one)") + p.add_argument("--max_prompt_length", type=int, default=320) + p.add_argument("--vae_batch_size", type=int, default=1) + + # Safety override for no_cfg + p.add_argument("--allow_bad_nocfg", action="store_true", default=False, + help="Suppress the no_cfg warning/block. Use at your own risk.") + + # Data + p.add_argument("--prompt_file", default=None, + help="Text file with one prompt per line (overrides defaults)") + p.add_argument("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS) + + # Device + p.add_argument("--device", type=int, default=0) + p.add_argument("--mixed_precision", default="bf16", + choices=["fp16", "bf16", "fp32"]) + + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def slug(text: str, max_len: int = 40) -> str: + s = text.lower() + s = "".join(c if c.isalnum() or c == " " else "" for c in s) + s = "_".join(s.split())[:max_len] + return s or "prompt" + + +def frames_to_mp4(frames, path: str, fps: int = 12): + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + if isinstance(frames, np.ndarray) and frames.ndim == 4: + frames = list(frames) + export_to_video(frames, output_video_path=path, fps=fps) + + +def _extract_frames(frames_output): + """Normalise pipeline output → list of uint8 numpy arrays [H, W, 3].""" + if isinstance(frames_output, np.ndarray): + frames_output = frames_output[0] if frames_output.ndim == 5 else frames_output + frames = list(frames_output) + elif isinstance(frames_output, list): + frames = [np.array(f) if not isinstance(f, np.ndarray) else f + for f in frames_output] + else: + raise TypeError(f"Unexpected frames type: {type(frames_output)}") + result = [] + for f in frames: + if f.dtype != np.uint8: + f = ((f * 255).clip(0, 255).astype(np.uint8) + if f.max() <= 1.0 else f.astype(np.uint8)) + result.append(f) + return result + + +def _gen(pipe, prompt, negative_prompt, seed, num_frames, height, width, + guidance_scale, num_inference_steps, guidance_trunc, + max_prompt_length, vae_batch_size, device): + """Single generation call, returns list of uint8 frames.""" + gen = torch.Generator(device=device).manual_seed(seed) + with torch.inference_mode(): + out = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + guidance_trunc=guidance_trunc, + max_prompt_length=max_prompt_length, + vae_batch_size=vae_batch_size, + output_type="np", + generator=gen, + ) + return _extract_frames(out.frames) + + +def _debug_pipeline(pipe, label=""): + """Print diagnostic info about the pipeline state.""" + print(f"\n{'='*60}") + print(f" Pipeline diagnostics {label}") + print(f"{'='*60}") + print(f" scheduler class : {type(pipe.scheduler).__name__}") + print(f" scheduler type : {type(pipe.scheduler)}") + if hasattr(pipe.scheduler, 'config'): + print(f" scheduler.config : {dict(pipe.scheduler.config)}") + if hasattr(pipe.scheduler, 'path'): + print(f" scheduler.path : {type(pipe.scheduler.path).__name__}") + if hasattr(pipe.scheduler.path, 'emb'): + emb = pipe.scheduler.path.emb + print(f" path.emb shape : {emb.shape}") + print(f" path.emb device : {emb.device}") + print(f" path.emb dtype : {emb.dtype}") + else: + print(f" scheduler.path : MISSING (scheduler not fully loaded!)") + print(f" codebook_size : {getattr(pipe.scheduler, 'codebook_size', 'N/A')}") + print(f" transformer class : {type(pipe.transformer).__name__}") + print(f" transformer device : {next(pipe.transformer.parameters()).device}") + print(f" vae class : {type(pipe.vae).__name__}") + if hasattr(pipe, 'image_processor'): + print(f" image_processor : {type(pipe.image_processor).__name__}") + print(f"{'='*60}\n") + + +def _debug_frames(frames, label=""): + """Print diagnostic info about generated frames.""" + if not frames: + print(f" [{label}] No frames generated!") + return + f0 = frames[0] + print(f" [{label}] n_frames={len(frames)} shape={f0.shape} " + f"dtype={f0.dtype} min={f0.min()} max={f0.max()}") + + +def _verify_state_dict_swap(pipe, state_dict, label=""): + """Verify transformer weights actually changed after load_state_dict.""" + sample_key = next(iter(state_dict.keys())) + loaded_val = state_dict[sample_key].flatten()[:8] + current_val = pipe.transformer.state_dict()[sample_key].flatten()[:8] + match = torch.allclose(loaded_val.cpu().float(), current_val.cpu().float(), atol=1e-6) + print(f" [{label}] state_dict match for '{sample_key}': {match}") + if not match: + print(f" loaded : {loaded_val[:4]}") + print(f" current : {current_val[:4]}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + args = parse_args() + + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + dtype = dtype_map[args.mixed_precision] + device = (torch.device("cuda", args.device) + if torch.cuda.is_available() else torch.device("cpu")) + os.makedirs(args.out_dir, exist_ok=True) + + # -- Verified regime check -------------------------------------------- + is_native = check_verified_regime( + height=args.height, + width=args.width, + num_frames=args.num_frames, + guidance_scale=args.eval_cfg_scale, + teacher_steps=args.teacher_steps, + label="eval", + ) + print(f"[eval] verified_native_regime={is_native}") + print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), " + f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}") + + # -- no_cfg safety gate ----------------------------------------------- + all_modes = list(args.student_modes) + list(args.teacher_modes) + if "no_cfg" in all_modes: + if args.allow_bad_nocfg: + print("[WARN] no_cfg is known to fail for this URSA checkpoint. " + "Outputs may be blank or blurry. Proceeding because --allow_bad_nocfg is set.") + else: + print("[WARN] no_cfg is known to fail for this URSA checkpoint. " + "Outputs may be blank or blurry. " + "Pass --allow_bad_nocfg to override this warning.") + + # -- Load prompts ----------------------------------------------------- + if args.prompt_file: + with open(args.prompt_file, encoding="utf-8") as f: + prompts = [l.strip() for l in f if l.strip() and not l.startswith("#")] + else: + prompts = DEFAULT_PROMPTS + + print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds " + f"| student modes={args.student_modes} " + f"| teacher modes={args.teacher_modes}") + print(f"[eval] guidance_scale={args.eval_cfg_scale} " + f"guidance_trunc={args.guidance_trunc} " + f"teacher_steps={args.teacher_steps}") + print(f"[eval] negative_prompt='{args.negative_prompt[:60]}...'") + + # -- Load pipeline (teacher) ------------------------------------------ + print(f"[eval] Loading pipeline from {args.teacher_ckpt} …") + # 【修改点 2】尝试启用 Flash Attention 2 + try: + pipe = URSAPipeline.from_pretrained( + args.teacher_ckpt, + torch_dtype=dtype, + trust_remote_code=True, + attn_implementation="flash_attention_2" + ).to(device) + except Exception: + # 如果环境不支持 FA2,降级到默认 + pipe = URSAPipeline.from_pretrained( + args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True + ).to(device) + + if hasattr(pipe.vae, "disable_slicing"): + pipe.vae.disable_slicing() + if hasattr(pipe.vae, "disable_tiling"): + pipe.vae.disable_tiling() + + # print("[eval] Compiling transformer (this takes ~2 mins for the first time)...") + # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead") + + # Diagnostic: verify scheduler loaded correctly + _debug_pipeline(pipe, label="after from_pretrained + .to(device)") + + # CRITICAL CHECK: scheduler must have .path with embeddings + if not hasattr(pipe.scheduler, 'path') or pipe.scheduler.path is None: + print("[ERROR] Scheduler path not loaded! This will cause blurry output.") + print("[ERROR] The scheduler needs scheduler_model.pth with codebook embeddings.") + return + + if pipe.scheduler.codebook_size == 0: + print("[ERROR] codebook_size=0 — scheduler not properly initialized!") + return + + # Save teacher state for switching back after student inference + teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()} + + # -- Load student checkpoint ------------------------------------------ + print(f"[eval] Loading student weights from {args.student_ckpt} …") + student_state = torch.load( + args.student_ckpt, map_location=device, weights_only=True + ) + print(f"[eval] student state_dict keys: {len(student_state)} " + f"sample key: {next(iter(student_state.keys()))}") + + # Common kwargs for every pipeline call + gen_kwargs = dict( + num_frames=args.num_frames, + height=args.height, + width=args.width, + guidance_trunc=args.guidance_trunc, + max_prompt_length=args.max_prompt_length, + vae_batch_size=args.vae_batch_size, + ) + + # Mode → guidance_scale mapping + student_guidance = { + "no_cfg": 1.0, + "cfg": args.eval_cfg_scale, + "baked": 1.0, + } + teacher_guidance = { + "no_cfg": 1.0, + "cfg": args.eval_cfg_scale, + } + + # -- Evaluation loop -------------------------------------------------- + for idx, prompt in enumerate(prompts): + p_slug = slug(prompt) + print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}") + + for seed in args.seeds: + # ---- Student: 1-step generation ----------------------------- + for mode in args.student_modes: + g_scale = student_guidance[mode] + neg = args.negative_prompt if g_scale > 1 else None + pipe.transformer.load_state_dict(student_state, strict=True) + pipe.transformer.eval() + + if idx == 0 and seed == args.seeds[0]: + _verify_state_dict_swap(pipe, student_state, f"student/{mode}") + + with torch.no_grad(): + frames = _gen(pipe, prompt, neg, seed, + guidance_scale=g_scale, + num_inference_steps=1, + device=device, **gen_kwargs) + + if idx == 0 and seed == args.seeds[0]: + _debug_frames(frames, f"student/{mode}") + + path = os.path.join( + args.out_dir, + f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4", + ) + frames_to_mp4(frames, path, fps=args.fps) + print(f" [student/{mode:6s}] seed={seed} scale={g_scale} → {path}") + + # ---- Teacher: multi-step reference -------------------------- + for t_mode in args.teacher_modes: + g_scale = teacher_guidance[t_mode] + neg = args.negative_prompt if g_scale > 1 else None + pipe.transformer.load_state_dict(teacher_state, strict=True) + pipe.transformer.eval() + + if idx == 0 and seed == args.seeds[0]: + _verify_state_dict_swap(pipe, teacher_state, f"teacher/{t_mode}") + + with torch.no_grad(): + frames = _gen(pipe, prompt, neg, seed, + guidance_scale=g_scale, + num_inference_steps=args.teacher_steps, + device=device, **gen_kwargs) + + if idx == 0 and seed == args.seeds[0]: + _debug_frames(frames, f"teacher/{t_mode}") + + path = os.path.join( + args.out_dir, + f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4", + ) + frames_to_mp4(frames, path, fps=args.fps) + print(f" [teacher/{t_mode:6s}] seed={seed} scale={g_scale} " + f"steps={args.teacher_steps} → {path}") + + print(f"\n[eval] Done. Results in {args.out_dir}") + _print_guide(args) + + +def _print_guide(args): + print(f""" +╔══════════════════════════════════════════════════════════════╗ +║ Interpretation guide ║ +╠══════════════════════════════════════════════════════════════╣ +║ student_1step_cfg : 1-step, guidance_scale={args.eval_cfg_scale:<4} ║ +║ (verified working student mode) ║ +║ student_1step_baked : 1-step, guidance_scale=1 ║ +║ (for students trained with CFG KD) ║ +║ teacher_{args.teacher_steps}step_cfg : {args.teacher_steps}-step, guidance_scale={args.eval_cfg_scale:<4} ║ +║ (verified working teacher mode) ║ +╠══════════════════════════════════════════════════════════════╣ +║ NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline ║ +║ for this URSA checkpoint. Use --allow_bad_nocfg to override.║ +╚══════════════════════════════════════════════════════════════╝""") + + +if __name__ == "__main__": + main() diff --git a/URSA/scripts/eval_onestep_ursa.py b/URSA/scripts/eval_onestep_ursa.py new file mode 100644 index 0000000000000000000000000000000000000000..00a500938cbff72c7a800a998c8923120db3b2ed --- /dev/null +++ b/URSA/scripts/eval_onestep_ursa.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------- +"""Evaluation script: compare student 1-step variants vs multi-step teacher. + +Verified native inference regime (from A/B testing — ground truth): + height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50. + no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint. + +Student generation modes +------------------------ + cfg : 1-step, guidance_scale=7 (verified working student mode) + baked : 1-step, guidance_scale=1 (for students trained with CFG KD) + +Teacher generation modes +------------------------ + cfg : 50-step, guidance_scale=7 (verified working teacher mode) + +Usage: + python scripts/eval_onestep_ursa.py \\ + --teacher_ckpt /path/to/URSA \\ + --student_ckpt ./outputs/dimo/final/student.pt \\ + --modes cfg \\ + --eval_cfg_scale 7.0 \\ + --num_frames 49 --height 320 --width 512 \\ + --teacher_steps 50 \\ + --out_dir ./outputs/eval +""" + +import argparse +import os +import sys + +import numpy as np +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from diffnext.pipelines import URSAPipeline +from diffnext.utils import export_to_video + + +# --------------------------------------------------------------------------- +# Default prompts and seeds +# --------------------------------------------------------------------------- + +DEFAULT_PROMPTS = [ + "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.", + "beautiful fireworks in the sky with red, white and blue.", + "a wave crashes on a rocky shoreline at sunset, slow motion.", + "a hummingbird hovers in front of a red flower, wings a blur.", + "timelapse of clouds rolling over mountain peaks.", + "a neon-lit city street at night with rain-soaked reflections.", + "a kitten playing with a ball of yarn on a wooden floor.", + "astronaut floating weightlessly inside a space station.", +] + +DEFAULT_SEEDS = [0, 1, 2, 3] + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="URSA 1-step student eval vs teacher") + + p.add_argument("--teacher_ckpt", required=True, help="URSA diffusers pipeline dir") + p.add_argument("--student_ckpt", required=True, + help="student.pt checkpoint from train_onestep_ursa_dimo.py") + p.add_argument("--out_dir", default="./outputs/eval") + + # Geometry (verified native: 320×512×49) + p.add_argument("--num_frames", type=int, default=49) + p.add_argument("--height", type=int, default=320) + p.add_argument("--width", type=int, default=512) + p.add_argument("--fps", type=int, default=12) + + # Generation — default: cfg only (no_cfg is known to fail) + p.add_argument("--modes", nargs="+", default=["cfg"], + choices=["no_cfg", "cfg", "baked"], + help="Student generation modes. Default: ['cfg']. " + "no_cfg is known to produce blank/blurry output.") + p.add_argument("--eval_cfg_scale", type=float, default=7.0, + help="Guidance scale for 'cfg' mode (verified working value=7)") + p.add_argument("--teacher_steps", type=int, default=50, + help="Inference steps for teacher (verified default=50)") + p.add_argument("--teacher_modes", nargs="+", default=["cfg"], + choices=["no_cfg", "cfg"], + help="Teacher modes. Default: ['cfg']. " + "no_cfg is NOT a valid baseline for this checkpoint.") + p.add_argument("--guidance_trunc", type=float, default=0.9, + help="Truncation threshold for inference CFG (passed to pipeline)") + p.add_argument("--max_prompt_length", type=int, default=320) + p.add_argument("--vae_batch_size", type=int, default=1) + + # Data + p.add_argument("--prompt_file", default=None, + help="Optional: text file with one prompt per line") + p.add_argument("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS) + + # Device + p.add_argument("--device", type=int, default=0) + p.add_argument("--mixed_precision", default="bf16", choices=["fp16", "bf16", "fp32"]) + + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def slug(text: str, max_len: int = 40) -> str: + s = text.lower() + s = "".join(c if c.isalnum() or c == " " else "" for c in s) + s = "_".join(s.split())[:max_len] + return s or "prompt" + + +def frames_to_mp4(frames, path: str, fps: int = 12): + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + if isinstance(frames, np.ndarray) and frames.ndim == 4: + frames = list(frames) + export_to_video(frames, output_video_path=path, fps=fps) + + +def _extract_frames(frames_output): + """Normalise pipeline output → list of uint8 numpy arrays [H, W, 3].""" + if isinstance(frames_output, np.ndarray): + frames_output = frames_output[0] if frames_output.ndim == 5 else frames_output + frames = list(frames_output) + elif isinstance(frames_output, list): + frames = [np.array(f) if not isinstance(f, np.ndarray) else f for f in frames_output] + else: + raise TypeError(f"Unexpected frames type: {type(frames_output)}") + result = [] + for f in frames: + if f.dtype != np.uint8: + f = (f * 255).clip(0, 255).astype(np.uint8) if f.max() <= 1.0 else f.astype(np.uint8) + result.append(f) + return result + + +DEFAULT_NEGATIVE_PROMPT = ( + "worst quality, low quality, inconsistent motion, static, still, " + "blurry, jittery, distorted, ugly" +) + + +def _gen(pipe, prompt, seed, num_frames, height, width, guidance_scale, + num_inference_steps, guidance_trunc, max_prompt_length, vae_batch_size, + device, negative_prompt=None): + """Single generation call, returns list of uint8 frames.""" + gen = torch.Generator(device=device).manual_seed(seed) + out = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + guidance_trunc=guidance_trunc, + max_prompt_length=max_prompt_length, + vae_batch_size=vae_batch_size, + output_type="np", + generator=gen, + ) + return _extract_frames(out.frames) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + args = parse_args() + + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + dtype = dtype_map[args.mixed_precision] + device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu") + os.makedirs(args.out_dir, exist_ok=True) + + # -- Verified regime validation ---------------------------------------- + _NATIVE = dict(height=320, width=512, num_frames=49, guidance_scale=7.0, teacher_steps=50) + is_native = ( + args.height == _NATIVE["height"] + and args.width == _NATIVE["width"] + and args.num_frames == _NATIVE["num_frames"] + and args.eval_cfg_scale == _NATIVE["guidance_scale"] + and args.teacher_steps == _NATIVE["teacher_steps"] + ) + print(f"[eval] verified_native_regime={is_native}") + print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), " + f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}") + if not is_native: + print(f"[WARN] Current config deviates from the verified native URSA regime " + f"({_NATIVE['num_frames']}×{_NATIVE['height']}×{_NATIVE['width']}, " + f"cfg={_NATIVE['guidance_scale']}, steps={_NATIVE['teacher_steps']}).") + + all_modes = list(args.modes) + list(args.teacher_modes) + if "no_cfg" in all_modes: + print("[WARN] no_cfg is known to fail for this URSA checkpoint. " + "Outputs may be blank or blurry.") + + # -- Load prompts ----------------------------------------------------- + if args.prompt_file: + with open(args.prompt_file, encoding="utf-8") as f: + prompts = [l.strip() for l in f if l.strip() and not l.startswith("#")] + else: + prompts = DEFAULT_PROMPTS + + print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds " + f"| student modes={args.modes} | teacher modes={args.teacher_modes}") + + # -- Load pipeline --------------------------------------------------- + print(f"[eval] Loading pipeline from {args.teacher_ckpt} …") + pipe = URSAPipeline.from_pretrained( + args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True + ).to(device) + + # -- Load student checkpoint ----------------------------------------- + print(f"[eval] Loading student weights from {args.student_ckpt} …") + student_state = torch.load(args.student_ckpt, map_location=device, weights_only=True) + teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()} + + # Common kwargs passed to every pipeline call + gen_kwargs = dict( + num_frames=args.num_frames, + height=args.height, + width=args.width, + guidance_trunc=args.guidance_trunc, + max_prompt_length=args.max_prompt_length, + vae_batch_size=args.vae_batch_size, + ) + + # Mode → guidance_scale mapping + # no_cfg : single forward, no guidance + # cfg : dual forward, eval_cfg_scale + # baked : single forward, no guidance (student trained with guided KD) + student_guidance = { + "no_cfg": 1.0, + "cfg": args.eval_cfg_scale, + "baked": 1.0, + } + teacher_guidance = { + "no_cfg": 1.0, + "cfg": args.eval_cfg_scale, + } + + # -- Evaluation loop ------------------------------------------------- + for idx, prompt in enumerate(prompts): + p_slug = slug(prompt) + print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}") + + for seed in args.seeds: + # ---- Student: selected modes -------------------------------- + for mode in args.modes: + g_scale = student_guidance[mode] + neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None + pipe.transformer.load_state_dict(student_state, strict=True) + pipe.transformer.eval() + + with torch.no_grad(): + frames = _gen(pipe, prompt, seed, + guidance_scale=g_scale, + num_inference_steps=1, + negative_prompt=neg, + device=device, **gen_kwargs) + + path = os.path.join( + args.out_dir, + f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4", + ) + frames_to_mp4(frames, path, fps=args.fps) + print(f" [student/{mode:6s}] seed={seed} scale={g_scale} → {path}") + + # ---- Teacher: reference videos ------------------------------ + for t_mode in args.teacher_modes: + g_scale = teacher_guidance[t_mode] + neg = DEFAULT_NEGATIVE_PROMPT if g_scale > 1 else None + pipe.transformer.load_state_dict(teacher_state, strict=True) + pipe.transformer.eval() + + with torch.no_grad(): + frames = _gen(pipe, prompt, seed, + guidance_scale=g_scale, + num_inference_steps=args.teacher_steps, + negative_prompt=neg, + device=device, **gen_kwargs) + + path = os.path.join( + args.out_dir, + f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4", + ) + frames_to_mp4(frames, path, fps=args.fps) + print(f" [teacher/{t_mode:6s}] seed={seed} scale={g_scale} " + f"steps={args.teacher_steps} → {path}") + + print(f"\n[eval] Done. Results in {args.out_dir}") + _print_interpretation_guide(args) + + +def _print_interpretation_guide(args): + print(f""" +╔══════════════════════════════════════════════════════════════╗ +║ Interpretation guide for generated videos ║ +╠══════════════════════════════════════════════════════════════╣ +║ student_1step_cfg : 1-step + CFG={args.eval_cfg_scale:<4} ║ +║ (verified working student mode) ║ +║ student_1step_baked : 1-step, guidance_scale=1 ║ +║ (for students trained with CFG KD) ║ +║ teacher_{args.teacher_steps}step_cfg : {args.teacher_steps}-step + CFG={args.eval_cfg_scale:<4} ║ +║ (verified working teacher mode) ║ +╠══════════════════════════════════════════════════════════════╣ +║ NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline ║ +║ for this URSA checkpoint — outputs are blank or blurry. ║ +╚══════════════════════════════════════════════════════════════╝""") + + +if __name__ == "__main__": + main() diff --git a/URSA/scripts/test_patches_mock.py b/URSA/scripts/test_patches_mock.py new file mode 100644 index 0000000000000000000000000000000000000000..2cfbf39f14cc046864420486cd1c3886015c0e94 --- /dev/null +++ b/URSA/scripts/test_patches_mock.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +"""Self-contained mock test for all 6 patches in train_onestep_ursa_dimo.py. + +Does NOT require loading the real URSA pipeline. +Exercises: + (1) Batch-concat [2B] forward — verified via forward call counts + (2) reward / adv detach — runtime assertions + (3) _stable_kl / _stable_jeffrey (float32 + log_softmax) + (4) Separate loss_aux_cond / loss_aux_uncond / loss_kd_cond / loss_kd_uncond logging + (5) use_guided per-sample shape [B] and ratio + (6) flex_attn offsets probe / reset + +Run: + python scripts/test_patches_mock.py +""" +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import types, copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Import helpers from the training script directly +import importlib.util +spec = importlib.util.spec_from_file_location( + "train", os.path.join(os.path.dirname(__file__), "train_onestep_ursa_dimo.py")) +train_mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(train_mod) + +_stable_kl = train_mod._stable_kl +_stable_jeffrey = train_mod._stable_jeffrey +_build_guided_logits = train_mod._build_guided_logits +_select_target = train_mod._select_target +_cfg_warmup_prob = train_mod._cfg_warmup_prob +_compute_cfg_scale = train_mod._compute_cfg_scale +_probe_flex_attn = train_mod._probe_flex_attn +_reset_flex_attn = train_mod._reset_flex_attn +_print_flex_attn_state = train_mod._print_flex_attn_state +_token_histogram_entropy = train_mod._token_histogram_entropy + +print("=" * 70) +print("URSA distillation patch self-test (mock)") +print("=" * 70) + +device = torch.device("cpu") +B, N, K = 2, 12, 64 # small numbers for speed + +# ========================================================================= +# Patch (3): _stable_kl / _stable_jeffrey — float32 + log_softmax +# ========================================================================= +print("\n[3] Testing _stable_kl / _stable_jeffrey …") +torch.manual_seed(0) +z_p = torch.randn(B, N, K) +z_q = torch.randn(B, N, K) + +kl_pq = _stable_kl(z_p, z_q) +kl_qp = _stable_kl(z_q, z_p) +jeff = _stable_jeffrey(z_p, z_q) + +assert kl_pq.shape == (B,), f"kl_pq shape={kl_pq.shape}" +assert (kl_pq >= 0).all(), "KL must be non-negative" +assert (kl_qp >= 0).all(), "KL must be non-negative (reverse)" +assert torch.allclose(jeff, kl_pq + kl_qp, atol=1e-5), "Jeffrey ≠ KL(p||q) + KL(q||p)" +assert not torch.isnan(kl_pq).any(), "kl_pq has NaN" +assert not torch.isinf(kl_pq).any(), "kl_pq has Inf" + +# KL(p||p) == 0 +kl_pp = _stable_kl(z_p, z_p) +assert kl_pp.abs().max() < 1e-5, f"KL(p||p) should be ~0, got {kl_pp}" + +# Numerics with large logits (simulate s=3 amplification) +z_large = z_p * 50.0 +kl_large = _stable_kl(z_large, z_q) +assert not torch.isnan(kl_large).any(), "kl_large has NaN with large logits" +assert not torch.isinf(kl_large).any(), "kl_large has Inf with large logits" + +print(f" kl_pq = {kl_pq.tolist()} (both ≥0 ✓)") +print(f" jeffrey= {jeff.tolist()} (= kl_pq + kl_qp ✓)") +print(f" kl(p,p)= {kl_pp.tolist()} (≈0 ✓)") +print(f" kl with z*50: {kl_large.tolist()} (finite ✓)") +print("[3] _stable_kl / _stable_jeffrey PASSED ✓") + +# ========================================================================= +# Patch (3b): _build_guided_logits — float32, per-sample scale +# ========================================================================= +print("\n[3b] Testing _build_guided_logits …") +z_cond = torch.randn(B, N, K) +z_uncond = torch.randn(B, N, K) +t = torch.tensor([0.3, 0.95]) # one below, one above trunc=0.9 +z_guided = _build_guided_logits(z_cond, z_uncond, t, cfg_scale=3.0, trunc=0.9) + +assert z_guided.shape == (B, N, K), f"z_guided.shape={z_guided.shape}" +assert not torch.isnan(z_guided).any(), "z_guided has NaN" +assert not torch.isinf(z_guided).any(), "z_guided has Inf" + +# Sample 0: t=0.3 < trunc → scale=3 +# z_guided[0] = z_uncond[0] + 3*(z_cond[0] - z_uncond[0]) +expected_0 = z_uncond[0] + 3.0 * (z_cond[0] - z_uncond[0]) +assert torch.allclose(z_guided[0], expected_0, atol=1e-5), "sample 0 guided mismatch" +# Sample 1: t=0.95 >= trunc → scale=1 +expected_1 = z_uncond[1] + 1.0 * (z_cond[1] - z_uncond[1]) +assert torch.allclose(z_guided[1], expected_1, atol=1e-5), "sample 1 (trunc) mismatch" + +g_min, g_max, g_mean = z_guided.min().item(), z_guided.max().item(), z_guided.mean().item() +print(f" z_T_guided shape={z_guided.shape} min={g_min:.3f} max={g_max:.3f} mean={g_mean:.3f}") +assert abs(g_min) < 1e4 and abs(g_max) < 1e4, f"guided logits exploded: [{g_min:.1e}, {g_max:.1e}]" +print("[3b] _build_guided_logits PASSED ✓") + +# ========================================================================= +# Patch (5): use_guided per-sample [B] shape + ratio +# ========================================================================= +print("\n[5] Testing per-sample use_guided …") +torch.manual_seed(42) + +# After warmup (step >> warmup_steps) → p = cfg_prob = 1.0 +prob_full = _cfg_warmup_prob(step=10000, cfg_prob=1.0, warmup_steps=2000) +assert abs(prob_full - 1.0) < 1e-6, f"full warmup prob={prob_full}" + +# During warmup at step=1000 with warmup_steps=2000 → p = 0.5 +prob_half = _cfg_warmup_prob(step=1000, cfg_prob=1.0, warmup_steps=2000) +assert abs(prob_half - 0.5) < 1e-6, f"half warmup prob={prob_half}" + +# Per-sample sampling +torch.manual_seed(0) +use_guided = torch.rand(B) < 0.5 # [B] bool +assert use_guided.shape == (B,), f"use_guided.shape={use_guided.shape}" +use_guided_ratio = use_guided.float().mean().item() +print(f" use_guided={use_guided.tolist()} ratio={use_guided_ratio:.2f}") + +# _select_target per-sample +z_target = _select_target(z_guided, z_cond, use_guided) +for b in range(B): + if use_guided[b]: + assert torch.allclose(z_target[b], z_guided[b]), f"sample {b}: guided not selected" + else: + assert torch.allclose(z_target[b], z_cond[b]), f"sample {b}: cond not selected" +print(f" _select_target: per-sample selection correct ✓") +print("[5] Per-sample use_guided PASSED ✓") + +# ========================================================================= +# Patch (1): Batch-concat [2B] — verified via a tiny linear net +# ========================================================================= +print("\n[1] Testing batch-concat [2B] forward equivalence …") + +class TinyModel(nn.Module): + def __init__(self): + super().__init__() + self.lin = nn.Linear(K, K, bias=False) + self._call_count = 0 + def forward(self, x): + self._call_count += 1 + return self.lin(x.float()) + +model = TinyModel() +x_cond = torch.randn(B, N, K) +x_uncond = torch.randn(B, N, K) + +# Separate forward (old way: 2 calls) +model._call_count = 0 +out_cond_sep = model(x_cond) +out_uncond_sep = model(x_uncond) +calls_sep = model._call_count # = 2 + +# Batch-concat forward (new way: 1 call) +model._call_count = 0 +x_dual = torch.cat([x_cond, x_uncond], dim=0) # [2B, N, K] +out_dual = model(x_dual) # [2B, N, K] +out_cond_bat, out_uncond_bat = out_dual.chunk(2, dim=0) +calls_bat = model._call_count # = 1 + +assert calls_sep == 2, f"sep calls={calls_sep}" +assert calls_bat == 1, f"batch calls={calls_bat}" +assert torch.allclose(out_cond_sep, out_cond_bat, atol=1e-5), "cond output mismatch" +assert torch.allclose(out_uncond_sep, out_uncond_bat, atol=1e-5), "uncond output mismatch" +print(f" Separate: {calls_sep} calls → batch: {calls_bat} call (identical outputs ✓)") +print("[1] Batch-concat forward PASSED ✓") + +# ========================================================================= +# Patch (2): reward / adv detach — no student gradient +# ========================================================================= +print("\n[2] Testing reward/adv detach …") + +z_T = torch.randn(B, N, K).detach() # teacher logits (no grad) +z_S_with_grad = torch.randn(B, N, K, requires_grad=True) # student logits (has grad) + +# Reward computation: z_S must be detached +reward = -_stable_kl(z_T.detach(), z_S_with_grad.detach(), tau=1.0) # [B] +assert not reward.requires_grad, \ + f"[BUG] reward.requires_grad={reward.requires_grad} — gradient leaked" + +baseline_ema = 0.0 +adv = (reward - baseline_ema).detach() +assert not adv.requires_grad, \ + f"[BUG] adv.requires_grad={adv.requires_grad} — detach failed" + +# Verify gradient DOES flow through logp (the differentiable path) +logits_gen = torch.randn(B, N, K, requires_grad=True) +p_gen = F.softmax(logits_gen / 1.0, dim=-1) +x_hat = torch.multinomial(p_gen.view(-1, K).detach(), 1).view(B, N) +logp = p_gen.clamp(1e-8).log().gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) # [B] +loss_pg = -(adv * logp).mean() +loss_pg.backward() +assert logits_gen.grad is not None, "logits_gen has no grad — REINFORCE broken" +assert logits_gen.grad.abs().max() > 0, "logits_gen grad is all zeros" + +print(f" reward.requires_grad={reward.requires_grad} (must be False ✓)") +print(f" adv.requires_grad={adv.requires_grad} (must be False ✓)") +print(f" logits_gen.grad max={logits_gen.grad.abs().max():.4f} (non-zero ✓)") +print("[2] Reward/adv detach PASSED ✓") + +# ========================================================================= +# Patch (4): Separate loss logging keys +# ========================================================================= +print("\n[4] Testing separate loss logging …") + +loss_aux_cond_v = _stable_jeffrey(z_T, z_T + torch.randn_like(z_T) * 0.1, tau=1.0).mean() +loss_aux_uncond_v = _stable_jeffrey(z_T, z_T + torch.randn_like(z_T) * 0.2, tau=1.0).mean() +loss_kd_cond = _stable_kl(z_T, z_S_with_grad, tau=1.0).mean() +loss_kd_uncond_v = _stable_kl(z_T, z_T + torch.randn_like(z_T) * 0.05, tau=1.0).mean() + +log_line = ( + f"[step 1] " + f"loss_aux_cond={loss_aux_cond_v.item():.4f} " + f"loss_aux_uncond={loss_aux_uncond_v.item():.4f} " + f"loss_kd_cond={loss_kd_cond.item():.4f} " + f"loss_kd_uncond={loss_kd_uncond_v.item():.4f} " + f"loss_pg=0.1234 H=3.123 tok_H=4.500 " + f"guided_ratio=0.50 baseline=0.0000 mean_logp=-3.45" +) +print(f" Sample log: {log_line}") +assert "loss_aux_cond=" in log_line +assert "loss_aux_uncond=" in log_line +assert "loss_kd_cond=" in log_line +assert "loss_kd_uncond=" in log_line +assert "guided_ratio=" in log_line +print("[4] Separate loss logging format PASSED ✓") + +# ========================================================================= +# Patch (6): flex_attn offsets probe / reset +# ========================================================================= +print("\n[6] Testing flex_attn probe / reset …") + +# Case A: model without flex_attn +class ModelNoFlex(nn.Module): + pass + +m_no_flex = ModelNoFlex() +fa = _probe_flex_attn(m_no_flex, "no_flex") +assert fa is None, f"Expected None, got {fa}" +_reset_flex_attn(m_no_flex, "no_flex", verbose=True) # should not raise +print(" Model without flex_attn: probe=None, reset is no-op ✓") + +# Case B: model WITH flex_attn — simulate FlexAttentionCausal2D +class FakeFlexAttn: + def __init__(self): + self.offsets = None + self.block_mask = None + self.cu_offsets = None + +class ModelWithFlex(nn.Module): + def __init__(self): + super().__init__() + self.flex_attn = FakeFlexAttn() + +m_flex = ModelWithFlex() +m_flex.flex_attn.offsets = [0, 50, 370] # simulate set offsets +m_flex.flex_attn.block_mask = "some_mask" +m_flex.flex_attn.cu_offsets = torch.tensor([0, 50, 370]) + +print(" Before reset:") +_print_flex_attn_state(m_flex, "test_model") +_reset_flex_attn(m_flex, "test_model", verbose=True) +print(" After reset:") +_print_flex_attn_state(m_flex, "test_model") + +assert m_flex.flex_attn.offsets is None, "offsets not reset" +assert m_flex.flex_attn.block_mask is None, "block_mask not reset" +assert m_flex.flex_attn.cu_offsets is None, "cu_offsets not reset" +print(" flex_attn.offsets=None, block_mask=None, cu_offsets=None ✓") +print("[6] flex_attn probe/reset PASSED ✓") + +# ========================================================================= +# z_T_guided explosion guard (from _run_assertions) +# ========================================================================= +print("\n[3c] Testing z_T_guided explosion guard …") +z_guided_ok = torch.randn(B, N, K) * 10 # normal magnitude +z_guided_bad = torch.randn(B, N, K) * 2e4 # exploded + +assert not torch.isnan(z_guided_ok).any() +assert not torch.isinf(z_guided_ok).any() +assert abs(z_guided_ok.min().item()) < 1e4 + +try: + big_min = z_guided_bad.min().item() + big_max = z_guided_bad.max().item() + assert abs(big_min) < 1e4 and abs(big_max) < 1e4, f"Explosion: [{big_min:.1e}, {big_max:.1e}]" + print(" ⚠️ explosion guard NOT triggered (unexpected)") +except AssertionError as e: + print(f" Explosion guard triggered correctly: {e} ✓") +print("[3c] z_T_guided explosion guard PASSED ✓") + +# ========================================================================= +# Token histogram entropy +# ========================================================================= +print("\n[misc] Testing _token_histogram_entropy …") +# Uniform: entropy = log(K) +x_uniform = torch.randint(0, K, (1, B * N)) +H_uniform = _token_histogram_entropy(x_uniform, K) +print(f" uniform entropy={H_uniform:.3f} log(K)={K ** 0 * torch.tensor(K).float().log().item():.3f}") + +# Collapsed: all tokens = 0 → entropy = 0 +x_collapsed = torch.zeros(1, B * N, dtype=torch.long) +H_collapsed = _token_histogram_entropy(x_collapsed, K) +assert H_collapsed < 0.01, f"collapsed entropy={H_collapsed} should be ~0" +print(f" collapsed entropy={H_collapsed:.4f} (≈0 ✓)") +print("[misc] _token_histogram_entropy PASSED ✓") + +# ========================================================================= +# Patch (7): extract_visual_logits — manual reconstruction +# ========================================================================= +print("\n[7] extract_visual_logits end-to-end alignment (mock) …") +import importlib.util as _ilu, sys as _sys +_spec = _ilu.spec_from_file_location( + "_utils", os.path.join(os.path.dirname(__file__), "..", "src", "distill", "utils_ursa_inputs.py")) +_utils = _ilu.module_from_spec(_spec) +_spec.loader.exec_module(_utils) +extract_visual_logits = _utils.extract_visual_logits + +# Case A: D == K (URSA default — lm_head outputs K logits directly) +B7, N7, K7 = 1, 20, 64 +L7 = 8 +logits_full_A = torch.randn(B7, L7 + N7 + 1, K7) # D == K +z_vis_A = extract_visual_logits(logits_full_A, N7, K7) +z_seq_A = logits_full_A[:, -(N7+1):-1] # raw causal slice [B, N, D=K] +delta_A = (z_vis_A - z_seq_A).abs().max().item() +assert delta_A < 1e-6, f"Case A (D==K) delta={delta_A}" +print(f" [7a] D={K7}==K: extract == raw slice, delta={delta_A:.2e} ✓") + +# Case B: D > K (lm_head larger than codebook — offset=D-K) +D7B = K7 + 10 +logits_full_B = torch.randn(B7, L7 + N7 + 1, D7B) +z_vis_B = extract_visual_logits(logits_full_B, N7, K7) +z_seq_B = logits_full_B[:, -(N7+1):-1] # [B, N, D] +z_man_B = z_seq_B[..., D7B - K7:] # [B, N, K] +delta_B = (z_vis_B - z_man_B).abs().max().item() +assert delta_B < 1e-6, f"Case B (D>K) delta={delta_B}" +print(f" [7b] D={D7B}>K={K7}: extract == z[..., D-K:], delta={delta_B:.2e} ✓") + +# Case C: latent_shift test (D >= latent_shift + K — full-vocab head) +latent_shift_C = 12 +D7C = latent_shift_C + K7 +logits_full_C = torch.randn(B7, L7 + N7 + 1, D7C) +# extract_visual_logits with D7C == D7C: D == K? No, D7C=76, K7=64, D>K +# internal: offset = D7C - K7 = 12 = latent_shift_C → should match [..., latent_shift_C:] +z_vis_C = extract_visual_logits(logits_full_C, N7, K7) +z_seq_C = logits_full_C[:, -(N7+1):-1] +z_man_C1 = z_seq_C[..., latent_shift_C:] # using latent_shift as offset +z_man_C2 = z_seq_C[..., D7C - K7:] # using D-K as offset (same) +assert torch.allclose(z_man_C1, z_man_C2), "C1 != C2" +delta_C = (z_vis_C - z_man_C1).abs().max().item() +assert delta_C < 1e-6, f"Case C (full-vocab) delta={delta_C}" +print(f" [7c] D={D7C}=latent_shift+K: extract == z[..., latent_shift:], delta={delta_C:.2e} ✓") +print("[7] extract_visual_logits alignment PASSED ✓") + +# ========================================================================= +# Patch (8): flex_attn semantics sanity (mock — no real model) +# ========================================================================= +print("\n[8] flex_attn semantics sanity (mock) …") +# Verify that _reset_flex_attn clears offsets and block_mask + +class FakeFlexAttn2: + def __init__(self): + self.offsets = [0, 50, 370] + self.block_mask = "mask_obj" + self.cu_offsets = torch.tensor([0, 50, 370]) + def set_offsets_by_lens(self, lens): + from itertools import accumulate + self.offsets = list(accumulate([0] + lens)) + self.block_mask = None + +class ModelFlex2: + def __init__(self): + self.flex_attn = FakeFlexAttn2() + +m8 = ModelFlex2() +print(f" [8] before reset: offsets={m8.flex_attn.offsets}") +_reset_flex_attn(m8, "m8", verbose=True) +assert m8.flex_attn.offsets is None +assert m8.flex_attn.block_mask is None +assert m8.flex_attn.cu_offsets is None +print(f" [8] after reset: offsets={m8.flex_attn.offsets} ✓") + +# Verify set_offsets_by_lens changes the offsets +m8.flex_attn.set_offsets_by_lens([16, 60]) +assert m8.flex_attn.offsets == [0, 16, 76], f"offsets={m8.flex_attn.offsets}" +_reset_flex_attn(m8, "m8") +assert m8.flex_attn.offsets is None +print(" [8] set_offsets_by_lens → reset cycle ✓") +print("[8] flex_attn semantics sanity PASSED (mock) ✓") + +# ========================================================================= +# Patch (9): logp/token reshape consistency +# ========================================================================= +print("\n[9] logp/token reshape consistency …") +import math as _math + +T9, H9, W9 = 3, 4, 5 +N9, B9, K9 = T9 * H9 * W9, 1, K + +torch.manual_seed(99) +z9 = torch.randn(B9, N9, K9) +p9 = F.softmax(z9 / 1.0, dim=-1) # [1, 60, K] + +x_hat_flat = torch.multinomial(p9.view(-1, K9), 1) # [N9, 1] +x_hat_1d = x_hat_flat.view(B9, N9) # [1, 60] +x_hat_4d = x_hat_1d.view(B9, T9, H9, W9) # [1, 3, 4, 5] + +# reshape round-trip +x_hat_back = x_hat_4d.view(B9, N9) +assert torch.equal(x_hat_1d, x_hat_back), "reshape round-trip FAILED" + +# logp +logp_all = p9.clamp(1e-8).log().gather(-1, x_hat_1d.unsqueeze(-1)).squeeze(-1) # [1, 60] +logp_sum = logp_all.sum(-1) + +# 10 spot-checks +torch.manual_seed(7) +positions = torch.randperm(N9)[:10].tolist() +for pos in positions: + tok_id = x_hat_1d[0, pos].item() + logp_man = _math.log(max(p9[0, pos, tok_id].item(), 1e-8)) + logp_gat = logp_all[0, pos].item() + diff = abs(logp_man - logp_gat) + assert diff < 1e-6, f"pos={pos} tok={tok_id} diff={diff:.2e}" + +print( + f" [9] T={T9},H={H9},W={W9} N={N9} K={K9} " + f"reshape ✓ 10 logp spots ✓ logp_sum={logp_sum.item():.3f}" +) +print("[9] logp/token reshape consistency PASSED ✓") + +# ========================================================================= +# Summary +# ========================================================================= +print("\n" + "=" * 70) +print("ALL 9 PATCHES PASSED ✓") +print("=" * 70) +print(""" +Patch summary: + (1) Batch-concat [2B]: single forward = identical results, half the calls ✓ + (2) reward/adv detach: no student grad, REINFORCE still flows via logp ✓ + (3) float32+log_softmax: KL≥0, KL(p,p)≈0, stable with large logits ✓ + (3b) guided logits: per-sample trunc, finite, explosion guard ✓ + (4) Separate loss log: loss_aux_cond/uncond + loss_kd_cond/uncond ✓ + (5) use_guided [B]: per-sample Bernoulli, correct warmup ramp ✓ + (6) flex_attn: probe returns None/object, reset clears all fields ✓ + (7) extract_visual_logits: D==K, D>K, full-vocab paths all verified ✓ + (8) flex_attn semantics: reset/set cycle correct (no real model needed) ✓ + (9) logp/token reshape: round-trip exact, 10 logp spot-checks < 1e-6 ✓ +""") diff --git a/URSA/scripts/train.py b/URSA/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2a402bd255cd7bf2e88d2e3cefe2afcd36b83f1d --- /dev/null +++ b/URSA/scripts/train.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------ +"""Train a diffnext model.""" + +import json +import os + +from diffnext.engine.train_engine import Trainer +from diffnext.engine.train_engine import engine_utils +from diffnext.utils import accelerate_utils +from diffnext.utils import omegaconf_utils + + +def prepare_checkpoints(config): + """Prepare checkpoints for model resuming. + + Args: + config (omegaconf.DictConfig) + The model config. + """ + config.experiment.setdefault("resume_from_checkpoint", "") + ckpt_dir = os.path.abspath(os.path.join(config.experiment.output_dir, "checkpoints")) + resume_iter, _ = 0, os.makedirs(ckpt_dir, exist_ok=True) + if config.experiment.resume_from_checkpoint == "latest": + ckpts = [_ for _ in os.listdir(ckpt_dir) if _.startswith("checkpoint-")] + if ckpts: + resume_iter, ckpt = sorted((int(_.split("-")[-1]), _) for _ in ckpts)[-1] + config.experiment.resume_from_checkpoint = os.path.join(ckpt_dir, ckpt) + elif config.experiment.resume_from_checkpoint: + resume_iter = int(os.path.split(config.experiment.resume_from_checkpoint).split("-")[-1]) + config.experiment.resume_iter = resume_iter + if resume_iter and not hasattr(config.model, "lora"): # Override the pretrained path. + config.pipeline.paths.pretrained_path = config.experiment.resume_from_checkpoint + + +def prepare_datasets(config, accelerator): + """Prepare datasets for model training. + + Args: + config (omegaconf.DictConfig) + The model config. + accelerator (accelerate.Accelerator) + The accelerator instance. + """ + dataset = config.train_dataloader.params.dataset + metadata = json.load(open(os.path.join(dataset, "METADATA"))) + config.train_dataloader.params.max_examples = metadata["entries"] + if "batch_size" in metadata: + batch_size = metadata["batch_size"][accelerator.process_index] + bucket_dataset = dataset + "/" + str(accelerator.process_index).zfill(3) + config.train_dataloader.params.dataset = bucket_dataset + config.train_dataloader.params.batch_size = config.training.batch_size = batch_size + if "num_metrics" in metadata: + config.training.num_metrics = metadata["num_metrics"] + elif "shard_id" not in config.train_dataloader.params: + # By default, we use dataset shards across all processes. + config.train_dataloader.params.update(accelerate_utils.get_ddp_shards(accelerator)) + + +def run_train(config, accelerator, logger): + """Start a model training task. + + Args: + config (omegaconf.DictConfig) + The model config. + accelerator (accelerate.Accelerator) + The accelerator instance. + logger (logging.Logger) + The logger instance. + """ + trainer = Trainer(config, accelerator, logger) + if accelerator.is_main_process: # Configs have already been determined. + config_path = os.path.join(config.experiment.output_dir, "config.yaml") + omegaconf_utils.save_config(config, config_path) + logger.info("#Params: %.2fM" % engine_utils.count_params(trainer.model)) + logger.info("Start training...") + trainer.train_loop() + trainer.ema.update(trainer.model) if trainer.ema else None + trainer.save() + + +def main(): + """Main entry point.""" + config = omegaconf_utils.get_config() + accelerator = accelerate_utils.build_accelerator(config, log_with="wandb") + accelerate_utils.build_wandb(config, accelerator=accelerator) + logger = accelerate_utils.set_logger(config.experiment.output_dir, accelerator=accelerator) + device_seed = config.training.seed + accelerator.process_index + config.training.gpu_id, config.training.seed = accelerator.device.index, device_seed + engine_utils.manual_seed(config.training.seed, (config.training.gpu_id, device_seed)) + prepare_checkpoints(config), prepare_datasets(config, accelerator) + logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}") + run_train(config, accelerator, logger) + + +if __name__ == "__main__": + main() diff --git a/URSA/scripts/train_distill_dimo.py b/URSA/scripts/train_distill_dimo.py new file mode 100644 index 0000000000000000000000000000000000000000..529e538dcdd043504fa1dc393357c6b73cdca6c0 --- /dev/null +++ b/URSA/scripts/train_distill_dimo.py @@ -0,0 +1,1293 @@ +#!/usr/bin/env python3 +# ------------------------------------------------------------------------ +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------ +"""URSA one-step distillation trainer (DiMO-style), 8-GPU distributed. + +Verified native inference regime (from A/B testing — ground truth): + height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50. + no_cfg (guidance_scale=1) is NOT a valid baseline for this URSA checkpoint. + Defaults in configs/distill_dimo.yaml are aligned to this regime. + +Launch command: + + accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \\ + --machine_rank 0 --num_machines 1 --num_processes 8 \\ + scripts/train_distill_dimo.py \\ + config="./configs/distill_dimo.yaml" \\ + experiment.output_dir="./experiments/distill_dimo" \\ + distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\ + distill.prompt_source="/data/Koala_36M_*.csv" \\ + distill.batch_size_per_gpu=1 + +Smoke test (single-GPU, 50 steps): + + accelerate launch --num_processes 1 \\ + scripts/train_distill_dimo.py \\ + config="./configs/distill_dimo.yaml" \\ + experiment.output_dir="./experiments/smoke" \\ + distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\ + distill.prompt_source="prompts.txt" \\ + training.max_train_steps=50 + +Algorithm summary (9 stages per iteration) +------------------------------------------ +Stage 1 Tokenize → txt_ids [B, L] (CPU in worker, moved to GPU in run_step) +Stage 2 x_init ~ Uniform(K) (+ p_init mixing from x_hat_prev) +Stage 3 no_grad student(x_init) → x_hat [B, N], logp for PG +Stage 4 x_t = scheduler.add_noise(x_hat_4d, t) [B,T,H,W], long +Stage 5 no_grad teacher(x_t) → z_T_cond [B,N,K] (+ uncond if CFG) +Stage 6 aux update × fake_rounds: Jeffrey(z_T_target, z_A_cond).backward() +Stage 7 student KD forward on x_t → z_S_cond [B,N,K] +Stage 8 reward = -KL(z_T_cond, z_S_cond) [detached]; adv = reward - baseline_ema +Stage 9 Two-backward: + 9a _no_sync_backward(lambda_kd * loss_kd) [frees KD graph] + 9b accelerator.backward(lambda_pg * loss_pg - lambda_ent * H_mean) + opt_student.step() +""" + +import collections +import copy +import os +import sys +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +from diffnext.engine import engine_utils +from diffnext.engine.lr_scheduler import CosineLR +from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import ( + URSADistillDiMOPipeline, + _get_logits, + _stable_kl, + _stable_jeffrey, + _build_guided_logits, + _cfg_warmup_prob, + _no_sync_backward, + _reset_flex_attn, + VERIFIED_NATIVE_DEFAULTS, + check_verified_regime, +) +from diffnext.utils import accelerate_utils +from diffnext.utils import omegaconf_utils +from diffnext.utils import profiler +from src.distill.prompt_dataset import ( + CSVSpec, + InfiniteDataLoader, + PromptDataset, + make_collate_fn, +) + + +# --------------------------------------------------------------------------- +# DistillTwinModel — single nn.Module wrapping student + aux for DeepSpeed +# --------------------------------------------------------------------------- + + +class DistillTwinModel(torch.nn.Module): + """Wrapper that holds both student and aux as sub-modules. + + DeepSpeed (via Accelerate) only allows a single model in + ``accelerator.prepare()``. This container satisfies that constraint + while keeping student and aux as separately addressable sub-modules + with independent param groups. + """ + + def __init__(self, student: torch.nn.Module, aux: torch.nn.Module): + super().__init__() + self.student = student + self.aux = aux + + def forward(self, which: str, input_ids, rope_pos=None, **kwargs): + if which == "student": + return self.student(input_ids, rope_pos=rope_pos, **kwargs) + elif which == "aux": + return self.aux(input_ids, rope_pos=rope_pos, **kwargs) + else: + raise ValueError(f"DistillTwinModel: unknown sub-model '{which}'") + + +# --------------------------------------------------------------------------- +# DistillTrainer +# --------------------------------------------------------------------------- + +class DistillTrainer: + """Training orchestrator for on-policy one-step distillation. + + Reuses the same accelerate / logger / checkpoint API as + ``diffnext.engine.train_engine.Trainer`` so the distributed setup is + identical to the original training framework. + + Key differences from standard Trainer: + - Three models (teacher frozen, student + aux trainable) + - Student and aux are wrapped in a single ``DistillTwinModel`` so that + only one ``accelerator.prepare()`` call is needed (DeepSpeed requirement) + - One optimizer with two param_groups: [0]=student, [1]=aux + - LR schedulers for both param groups + - Two-backward strategy within each step + - PromptDataset (no video latents; prompt-only) + - Stage 6 freezes student / unfreezes aux; Stages 7-9 do the reverse + """ + + def __init__(self, config, accelerator, logger): + self.config = config + self.accelerator = accelerator + self.logger = logger + + cfg = config.distill + dtype = accelerate_utils.precision_to_dtype(config.training.mixed_precision) + self.device = accelerator.device + + # -------- Pipeline (teacher + student + aux) ---------------------- + logger.info(f"[init] Loading teacher from {cfg.teacher_ckpt} ...") + self.pipe = URSADistillDiMOPipeline( + teacher_ckpt=cfg.teacher_ckpt, + compute_dtype=dtype, + aux_noise_std=float(cfg.get("aux_noise_std", 0.0)), + ) + + # Move teacher to GPU (not prepared by accelerate — frozen). + self.pipe.teacher = self.pipe.teacher.to(self.device) + self.pipe.scheduler.to(device=self.device) + + # Compute latents shape from video geometry. + from src.distill.utils_ursa_inputs import compute_latents_shape + + # Read VAE strides from pipeline (falls back to URSA defaults 4/8). + vae_t = int(getattr(self.pipe, "vae_temporal_stride", 4)) + vae_s = int(getattr(self.pipe, "vae_spatial_stride", 8)) + self.latents_shape = compute_latents_shape( + cfg.num_frames, cfg.height, cfg.width, vae_t, vae_s + ) + T, H, W = self.latents_shape + self.N = T * H * W + self.K = self.pipe.codebook_size + logger.info( + f"[init] latents_shape=({T},{H},{W}) N={self.N} K={self.K} " + f"CFG={'ON' if cfg.enable_teacher_cfg else 'OFF'}" + ) + + # Pre-compute uncond token IDs (empty string, [1, L]) on CPU. + self.txt_uncond_base_cpu = self.pipe.tokenizer( + [""], + max_length=int(cfg.max_prompt_length), + padding="max_length", + padding_side="left", + truncation=True, + return_tensors="pt", + ).input_ids # [1, L] CPU + + # -------- Optimizers (before accelerate.prepare) ------------------ + # Single optimizer with two param groups: + # group[0] = student params, group[1] = aux params + opt_cls = torch.optim.AdamW + opt_s_params = dict( + lr=float(config.optimizer_student.params.lr), + betas=tuple(config.optimizer_student.params.get("betas", [0.9, 0.95])), + weight_decay=float(config.optimizer_student.params.get("weight_decay", 0.01)), + ) + opt_a_params = dict( + lr=float(config.optimizer_aux.params.lr), + betas=tuple(config.optimizer_aux.params.get("betas", [0.9, 0.95])), + weight_decay=float(config.optimizer_aux.params.get("weight_decay", 0.01)), + ) + + def _enable_gcpt(m): + # m.model.layers 是 Qwen3Model 的层列表 + for layer in m.model.layers: + layer.gradient_checkpointing = True + layer.self_attn.gradient_checkpointing = True + layer.mlp.gradient_checkpointing = True + + _enable_gcpt(self.pipe.student) + _enable_gcpt(self.pipe.aux) + + # -------- 断点续传:在 ZeRO-3 切分参数前加载权重 ------------------ + self.global_step = int(config.experiment.get("resume_iter", 0)) + if self.global_step > 0: + ckpt_dir = os.path.join( + config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}" + ) + if os.path.exists(ckpt_dir): + logger.info(f"[Resume] 正在从 {ckpt_dir} 恢复 Student 和 Aux 的权重...") + # 必须在 map_location="cpu" 下加载,防止爆显存,随后 prepare 会自动分配 + self.pipe.student.load_state_dict(torch.load(os.path.join(ckpt_dir, "student.pt"), map_location="cpu")) + self.pipe.aux.load_state_dict(torch.load(os.path.join(ckpt_dir, "aux.pt"), map_location="cpu")) + else: + logger.warning(f"[Resume] 找不到检查点 {ckpt_dir},将从随机初始状态起步!") + + # -------- Wrap student + aux into a single DistillTwinModel -------- + twin_model = DistillTwinModel(self.pipe.student, self.pipe.aux) + + opt_raw = opt_cls([ + {"params": list(self.pipe.student.parameters()), **opt_s_params}, + {"params": list(self.pipe.aux.parameters()), **opt_a_params}, + ]) + + # -------- accelerate.prepare: single model + single optimizer ------ + # Teacher is NOT prepared (frozen; no grad sync needed). + self.model, self.optimizer = accelerator.prepare(twin_model, opt_raw) + + # LR schedulers (step() called manually at end of each step). + self.scheduler_s = CosineLR( + lr_max=float(config.optimizer_student.params.lr), + lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)), + max_steps=int(config.training.max_train_steps), + warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)), + ) + self.scheduler_a = CosineLR( + lr_max=float(config.optimizer_aux.params.lr), + lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)), + max_steps=int(config.training.max_train_steps), + warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)), + ) + + # -------- Dataset / DataLoader ------------------------------------ + dataloader_cfg = config.get("prompt_dataloader", {}) + dataset = PromptDataset( + prompt_source=str(cfg.prompt_source), + shuffle_files=bool(dataloader_cfg.get("shuffle_files", True)), + shuffle_buffer=int(dataloader_cfg.get("shuffle_buffer", 0)), + seed=int(config.training.seed), + infinite=True, + csv=CSVSpec(caption_field=str(dataloader_cfg.get("caption_field", "caption"))), + ) + + # collate_fn: tokenize on CPU (no CUDA in workers). + collate_fn = make_collate_fn( + self.pipe.tokenizer, + max_prompt_length=int(cfg.max_prompt_length), + device=torch.device("cpu"), # CPU output — moved to GPU in run_step + ) + + loader = DataLoader( + dataset, + batch_size=int(cfg.batch_size_per_gpu), + shuffle=False, # IterableDataset: no shuffle flag + drop_last=True, + num_workers=int(dataloader_cfg.get("num_workers", 2)), + collate_fn=collate_fn, + pin_memory=True, + ) + # DataLoader is NOT prepared by accelerate because PromptDataset + # handles per-rank file sharding internally via torch.distributed. + self._inf_loader = InfiniteDataLoader(loader) + + # -------- Training state ------------------------------------------ + # self.global_step = int(config.experiment.get("resume_iter", 0)) + self.baseline_ema: float = 0.0 + self.x_hat_prev: Optional[torch.Tensor] = None + self.metrics = collections.OrderedDict() + + # -------- Verified regime validation -------------------------------- + native = VERIFIED_NATIVE_DEFAULTS + is_native = check_verified_regime( + height=int(cfg.height), + width=int(cfg.width), + num_frames=int(cfg.num_frames), + guidance_scale=float(cfg.teacher_cfg_scale) if cfg.enable_teacher_cfg else None, + label="train", + ) + logger.info( + f"[init] verified_native_regime={is_native} " + f"geometry=({cfg.num_frames}×{cfg.height}×{cfg.width}) " + f"teacher_cfg_scale={cfg.teacher_cfg_scale if cfg.enable_teacher_cfg else 'OFF'}" + ) + if not cfg.enable_teacher_cfg: + logger.warning( + "[WARN] Teacher CFG is DISABLED. no_cfg is known to produce " + "blank/blurry output for this URSA checkpoint. " + "Distillation without CFG is unlikely to produce useful results." + ) + elif float(cfg.teacher_cfg_scale) != native["guidance_scale"]: + logger.warning( + f"[WARN] teacher_cfg_scale={cfg.teacher_cfg_scale} differs from " + f"the verified working value ({native['guidance_scale']}). " + "Outputs may deviate from the official inference working point." + ) + + logger.info( + f"[init] student params: {engine_utils.count_params(self.pipe.student):.2f}M" + ) + logger.info( + f"[init] max_train_steps={config.training.max_train_steps} " + f"batch_size_per_gpu={cfg.batch_size_per_gpu} " + f"num_processes={accelerator.num_processes}" + ) + + # ----------------------------------------------------------------------- + # run_step: Stages 1-9 + # ----------------------------------------------------------------------- + + def run_step(self, step: int) -> dict: + """Execute one distillation step (Stages 1-9).""" + cfg = self.config.distill + T, H, W = self.latents_shape + N, K = self.N, self.K + device = self.device + stats = {"step": step} + + timer = profiler.Timer().tic() + + # Update LR from cosine schedulers. + # param_groups[0] = student, param_groups[1] = aux + lr_s = self.scheduler_s.get_lr() + lr_a = self.scheduler_a.get_lr() + stats["lr_student"] = lr_s + stats["lr_aux"] = lr_a + self.optimizer.param_groups[0]["lr"] = lr_s + self.optimizer.param_groups[1]["lr"] = lr_a + + # ---------------------------------------------------------------- + # Stage 1: Get tokenised batch (CPU → GPU) + # ---------------------------------------------------------------- + txt_ids = next(self._inf_loader) # [B, L] CPU tensor + txt_ids = txt_ids.to(device, non_blocking=True) + B = txt_ids.size(0) + + txt_uncond = None + if cfg.enable_teacher_cfg: + txt_uncond = self.txt_uncond_base_cpu.expand(B, -1).to(device) + + # # ---------------------------------------------------------------- + # # Stage 2: Sample x_init ~ Uniform(K) with optional p_init mixing + # # ---------------------------------------------------------------- + # x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long) + # if self.x_hat_prev is not None and float(cfg.p_init_mix_ratio) > 0: + # n_mix = max(1, int(B * float(cfg.p_init_mix_ratio))) + # x_init[:n_mix] = self.pipe.corrupt_tokens( + # self.x_hat_prev[:n_mix], r=float(cfg.p_mix_corrupt_frac) + # ) + # ---------------------------------------------------------------- + # Stage 2: Sample x_init ~ Uniform(K) with optional p_init mixing + # ---------------------------------------------------------------- + x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long) + + # 修复:使用概率触发,确保小 Batch 时模型依然能充分学习处理纯噪声 + if self.x_hat_prev is not None and float(cfg.p_init_mix_ratio) > 0: + if torch.rand(1).item() < float(cfg.p_init_mix_ratio): + # 如果触发,只混合 batch 里的第一个样本 + x_init[0] = self.pipe.corrupt_tokens( + self.x_hat_prev[0:1], r=float(cfg.p_mix_corrupt_frac) + ).squeeze(0) + + # ---------------------------------------------------------------- + # Stage 3: Student 1-step on x_init — no_grad (only sample x_hat) + # + # Gradient-enabled forward on x_init is deferred to Stage 9b so + # the KD computation graph (Stage 7, x_t) can be freed first. + # ---------------------------------------------------------------- + with torch.no_grad(): + ids_init, rpos_init, _ = self.pipe.build_inputs( + txt_ids, x_init, self.latents_shape + ) + logits_s_init = _get_logits( + self.model("student", ids_init, rope_pos=rpos_init) + ) + z_s = self.pipe.extract_logits(logits_s_init, N) # [B, N, K] + p_s = F.softmax(z_s / float(cfg.tau), dim=-1) # [B, N, K] + x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) # [B, N] + + # if step == 1: + # # 只抽 8 个 token 做 sum=1 检查,别全量 + # idx = torch.randint(0, N, (8,), device=device) + # p_err = (p_s[:, idx].sum(-1) - 1).abs().max().item() + # assert p_err < 1e-3, f"p_s subset not normalised: {p_err}" + del p_s, z_s, logits_s_init + + x_hat_4d = x_hat.view(B, T, H, W) + + # ---------------------------------------------------------------- + # Stage 4: Pseudo-intermediate x_t = add_noise(x_hat, t) + # ---------------------------------------------------------------- + t = self.pipe.sample_t_curriculum( + B, device, step, int(cfg.t_curriculum_steps) + ) # [B] float ∈ (0.05, 0.995) + with torch.no_grad(): + x_t = self.pipe.scheduler.add_noise(x_hat_4d, t) # [B,T,H,W] long + + # # ---------------------------------------------------------------- + # # Stage 5: Teacher forward — single [2B] forward when CFG enabled + # # ---------------------------------------------------------------- + # with torch.no_grad(): + # if cfg.enable_teacher_cfg: + # txt_dual = torch.cat([txt_ids, txt_uncond], dim=0) # [2B, L] + # x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B,T,H,W] + # ids_dual, rpos_dual, _ = self.pipe.build_inputs( + # txt_dual, x_t_dual, self.latents_shape + # ) + # logits_T_dual = _get_logits( + # self.pipe.teacher(ids_dual, rope_pos=rpos_dual) + # ) + # z_T_dual = self.pipe.extract_logits(logits_T_dual, N) # [2B,N,K] + # z_T_cond, z_T_uncond = z_T_dual.chunk(2, dim=0) # [B,N,K] + + # del logits_T_dual, z_T_dual + # torch.cuda.empty_cache() + + # ids_t, rpos_t = ids_dual[:B], rpos_dual[:B] + # else: + # ids_t, rpos_t, _ = self.pipe.build_inputs( + # txt_ids, x_t, self.latents_shape + # ) + # logits_T = _get_logits( + # self.pipe.teacher(ids_t, rope_pos=rpos_t) + # ) + # z_T_cond = self.pipe.extract_logits(logits_T, N) # [B,N,K] + # z_T_uncond = None + # ids_dual, rpos_dual = ids_t, rpos_t + + # # CFG guided target with per-sample Bernoulli warmup. + # z_T_guided = None + # use_guided_ratio = 0.0 + # if cfg.enable_teacher_cfg: + # p_guided = _cfg_warmup_prob( + # step, + # float(cfg.teacher_cfg_prob), + # int(cfg.teacher_cfg_warmup_steps), + # ) + # use_guided = torch.rand(B, device=device) < p_guided # [B] bool + # use_guided_ratio = float(use_guided.float().mean().item()) + # z_T_guided = _build_guided_logits( + # z_T_cond, z_T_uncond, + # t, float(cfg.teacher_cfg_scale), float(cfg.teacher_cfg_trunc), + # ) + # mask = use_guided.view(-1, 1, 1).expand_as(z_T_cond) + # z_T_target = torch.where(mask, z_T_guided, z_T_cond.float()) + # else: + # z_T_target = z_T_cond + + # z_T_target = z_T_target.detach() # NO grad path to teacher + + # # # ---------------------------------------------------------------- + # # # Stage 6: Aux update — fake_rounds iterations + # # # + # # # Freeze student so only aux gets gradients. With a single + # # # DeepSpeed-wrapped optimizer this is the cleanest way to ensure + # # # only aux params are updated. + # # # ---------------------------------------------------------------- + # # raw_twin = self.accelerator.unwrap_model(self.model) + # # raw_twin.student.requires_grad_(False) + # # raw_twin.aux.requires_grad_(True) + + # # loss_aux_cond_last = torch.tensor(0.0, device=device) + # # loss_aux_uncond_last = torch.tensor(0.0, device=device) + # # loss_aux_cond_sample_last = None + + # # for _fr in range(int(cfg.fake_rounds)): + # # self.optimizer.zero_grad(set_to_none=True) + + # # if cfg.enable_teacher_cfg: + # # logits_A_dual = _get_logits( + # # self.model("aux", ids_dual.detach(), rope_pos=rpos_dual.detach()) + # # ) + # # z_A_dual = self.pipe.extract_logits(logits_A_dual, N) # [2B,N,K] + # # z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0) + + # # loss_aux_cond_sample = _stable_jeffrey( + # # z_T_target, z_A_cond, float(cfg.tau_kd),chunk_size=1024 + # # ) # [B] + # # loss_aux_cond_v = loss_aux_cond_sample.mean() + # # loss_aux_uncond_v = _stable_jeffrey( + # # z_T_uncond.float().detach(), z_A_uncond, float(cfg.tau_kd),chunk_size=1024 + # # ).mean() + # # loss_aux_v = ( + # # loss_aux_cond_v + # # + float(cfg.lambda_kd_uncond) * loss_aux_uncond_v + # # ) + # # else: + # # logits_A = _get_logits( + # # self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) + # # ) + # # z_A_cond = self.pipe.extract_logits(logits_A, N) + # # loss_aux_cond_sample = _stable_jeffrey( + # # z_T_target, z_A_cond, float(cfg.tau_kd),chunk_size=1024 + # # ) + # # loss_aux_cond_v = loss_aux_cond_sample.mean() + # # loss_aux_uncond_v = torch.tensor(0.0, device=device) + # # loss_aux_v = loss_aux_cond_v + + # # self.accelerator.backward(loss_aux_v) + # # if float(cfg.grad_clip) > 0: + # # torch.nn.utils.clip_grad_norm_( + # # raw_twin.aux.parameters(), float(cfg.grad_clip) + # # ) + # # self.optimizer.step() + # # self.optimizer.zero_grad(set_to_none=True) + + # # loss_aux_cond_last = loss_aux_cond_v.detach() + # # loss_aux_uncond_last = loss_aux_uncond_v.detach() + # # loss_aux_cond_sample_last = loss_aux_cond_sample.detach() # [B] + + # # # ---------------------------------------------------------------- + # # # Stage 7: Student KD forward on x_t (with grad) + # # # + # # # Switch: freeze aux, unfreeze student for Stages 7-9. + # # # ---------------------------------------------------------------- + # # raw_twin.student.requires_grad_(True) + # # raw_twin.aux.requires_grad_(False) + # # self.optimizer.zero_grad(set_to_none=True) + + # # if cfg.enable_teacher_cfg: + # # logits_S_dual = _get_logits( + # # self.model("student", ids_dual.detach(), rope_pos=rpos_dual.detach()) + # # ) + # # z_S_dual = self.pipe.extract_logits(logits_S_dual, N) + # # z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) + # # loss_kd_cond = _stable_kl( + # # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=2048 + # # ).mean() + # # loss_kd_uncond = _stable_kl( + # # z_T_uncond.float().detach(), z_S_uncond, float(cfg.tau_kd), chunk_size=2048 + # # ).mean() + # # loss_kd = loss_kd_cond + float(cfg.lambda_kd_uncond) * loss_kd_uncond + # # else: + # # logits_S = _get_logits( + # # self.model("student", ids_t.detach(), rope_pos=rpos_t.detach()) + # # ) + # # z_S_cond = self.pipe.extract_logits(logits_S, N) + # # loss_kd_cond = _stable_kl( + # # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=2048 + # # ).mean() + # # loss_kd_uncond = torch.tensor(0.0, device=device) + # # loss_kd = loss_kd_cond + + # # # ---------------------------------------------------------------- + # # # Stage 8: Reward + advantage (fully detached — no student grad) + # # # + # # # INVARIANT: reward and adv must never carry student gradients. + # # # ---------------------------------------------------------------- + # # if cfg.enable_teacher_cfg and cfg.reward_use_guided: + # # z_T_for_rew = z_T_target # already detached + # # else: + # # z_T_for_rew = z_T_cond.detach() + + # # # reward[b] = -KL(z_T_cond || z_S_cond) with BOTH inputs detached + # # with torch.no_grad(): + # # reward = -_stable_kl( + # # z_T_for_rew.detach(), z_S_cond.detach(), float(cfg.tau), chunk_size=1024 + # # ) # [B] + # # assert not reward.requires_grad, ( + # # "[BUG] reward.requires_grad=True — student grad leaked into reward. " + # # "z_S_cond must be detached before KL for reward." + # # ) + # # self.baseline_ema = ( + # # 0.99 * self.baseline_ema + 0.01 * float(reward.mean().item()) + # # ) + # # adv = (reward - self.baseline_ema).detach() # [B] + # # assert not adv.requires_grad, "[BUG] adv.requires_grad=True" + + # # # ---------------------------------------------------------------- + # # # Stage 9: Two-backward student update + # # # + # # # 9a) KD backward first — frees the KD graph to save memory. + # # # Uses no_sync() (no DDP all-reduce) so gradients are not + # # # double-reduced when the PG backward syncs in 9b. + # # # 9b) Fresh forward on x_init WITH grad → PG + entropy backward. + # # # DDP all-reduce happens here (normal backward). + # # # ---------------------------------------------------------------- + + # # # 9a: KD backward (no sync — first of two backwards) + # # _no_sync_backward( + # # self.accelerator, self.model, float(cfg.lambda_kd) * loss_kd + # # ) + + # # # 9b: Policy + entropy — fresh forward on x_init WITH grad + # # ids_init, rpos_init, _ = self.pipe.build_inputs( + # # txt_ids, x_init, self.latents_shape + # # ) + # # logits_s_pol = _get_logits( + # # self.model("student", ids_init, rope_pos=rpos_init) + # # ) + # # z_s_pol = self.pipe.extract_logits(logits_s_pol, N) # [B, N, K] + + # # logp_tok = F.log_softmax(z_s_pol / float(cfg.tau), dim=-1) # [B, N, K] + # # p_s_pol = logp_tok.exp() + + # # # per-token average log-prob (recommended over log-prob sum) + # # logp_sum = ( + # # logp_tok.gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) + # # ) # [B] + # # logp = logp_sum / N # [B] per-token logp + + # # H_mean = -(p_s_pol * logp_tok).sum(-1).mean() + + # # loss_pg = -(adv * logp).mean() + # # lambda_ent_eff = float(cfg.lambda_ent) * (1.0 + 2.0 * use_guided_ratio) + + # # # Second backward: DDP all-reduce happens here. + # # self.accelerator.backward( + # # float(cfg.lambda_pg) * loss_pg - lambda_ent_eff * H_mean + # # ) + + # # if float(cfg.grad_clip) > 0: + # # torch.nn.utils.clip_grad_norm_( + # # raw_twin.student.parameters(), float(cfg.grad_clip) + # # ) + # # self.optimizer.step() + + # # # Restore both sub-modules to trainable for next step. + # # raw_twin.student.requires_grad_(True) + # # raw_twin.aux.requires_grad_(True) + + # # # p_init mixing: store x_hat_4d (detached) for next step. + # # self.x_hat_prev = x_hat_4d.detach() + + # # ---------------------------------------------------------------- + # # Stage 6: Aux update — Fit sampled pseudo-target (x_hat) from student + # # ---------------------------------------------------------------- + # raw_twin = self.accelerator.unwrap_model(self.model) + # raw_twin.student.requires_grad_(False) + # raw_twin.aux.requires_grad_(True) + + # target_tokens = x_hat.detach() # [B, N] - 学生在 Stage 3 盲猜出来的画面 + + # for _fr in range(int(cfg.fake_rounds)): + # self.optimizer.zero_grad(set_to_none=True) + + # if cfg.enable_teacher_cfg: + # logits_A_dual = _get_logits( + # self.model("aux", ids_dual.detach(), rope_pos=rpos_dual.detach()) + # ) + # z_A_dual = self.pipe.extract_logits(logits_A_dual, N) # [2B,N,K] + # z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0) + + # # Aux 拟合学生的假 token (Cross Entropy) + # loss_aux_cond_v = F.cross_entropy( + # z_A_cond.reshape(B * N, K), + # target_tokens.reshape(B * N), + # reduction="mean", + # ) + # loss_aux_uncond_v = F.cross_entropy( + # z_A_uncond.reshape(B * N, K), + # target_tokens.reshape(B * N), + # reduction="mean", + # ) + # loss_aux_v = loss_aux_cond_v + float(cfg.lambda_kd_uncond) * loss_aux_uncond_v + # else: + # logits_A = _get_logits( + # self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) + # ) + # z_A_cond = self.pipe.extract_logits(logits_A, N) + + # loss_aux_cond_v = F.cross_entropy( + # z_A_cond.reshape(B * N, K), + # target_tokens.reshape(B * N), + # reduction="mean", + # ) + # loss_aux_uncond_v = torch.tensor(0.0, device=device) + # loss_aux_v = loss_aux_cond_v + + # self.accelerator.backward(loss_aux_v) + + # if float(cfg.grad_clip) > 0: + # torch.nn.utils.clip_grad_norm_( + # raw_twin.aux.parameters(), float(cfg.grad_clip) + # ) + # self.optimizer.step() + + # loss_aux_cond_last = loss_aux_cond_v.detach() + + # # ---------------------------------------------------------------- + # # Stage 7 & 8: Student KD update & Aux Bridge (Gradient Injection) + # # ---------------------------------------------------------------- + # raw_twin.student.requires_grad_(True) + # raw_twin.aux.requires_grad_(False) + # self.optimizer.zero_grad(set_to_none=True) + + # # 7a. Student KD forward on x_t (保持原样) + # if cfg.enable_teacher_cfg: + # logits_S_dual = _get_logits( + # self.model("student", ids_dual.detach(), rope_pos=rpos_dual.detach()) + # ) + # z_S_dual = self.pipe.extract_logits(logits_S_dual, N) + # z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) + + # # --- [新增] 立刻释放显存 --- + # del logits_S_dual, z_S_dual + + # loss_kd_cond = _stable_kl( + # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=256 #2048 + # ).mean() + # loss_kd_uncond = _stable_kl( + # z_T_uncond.float().detach(), z_S_uncond, float(cfg.tau_kd), chunk_size=256 #2048 + # ).mean() + # loss_kd = loss_kd_cond + float(cfg.lambda_kd_uncond) * loss_kd_uncond + # else: + # logits_S = _get_logits( + # self.model("student", ids_t.detach(), rope_pos=rpos_t.detach()) + # ) + # z_S_cond = self.pipe.extract_logits(logits_S, N) + # loss_kd_cond = _stable_kl( + # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=256 #2048 + # ).mean() + # loss_kd_uncond = torch.tensor(0.0, device=device) + # loss_kd = loss_kd_cond + + # # 7b. 获取 Aux 的预测 (无梯度) 作为计算桥梁 + # with torch.no_grad(): + # if cfg.enable_teacher_cfg: + # logits_A_dual = _get_logits( + # self.model("aux", ids_dual.detach(), rope_pos=rpos_dual.detach()) + # ) + # z_A_dual = self.pipe.extract_logits(logits_A_dual, N) + # z_A_cond, _ = z_A_dual.chunk(2, dim=0) + + # # --- [新增] 立刻释放显存 --- + # del logits_A_dual, z_A_dual + # else: + # logits_A = _get_logits( + # self.model("aux", ids_t.detach(), rpos_t.detach()) + # ) + # z_A_cond = self.pipe.extract_logits(logits_A, N) + + # # 8. Student 对初始噪声 x_init 进行带梯度的前向传播 + # ids_init, rpos_init, _ = self.pipe.build_inputs( + # txt_ids, x_init, self.latents_shape + # ) + # logits_s_pol = _get_logits( + # self.model("student", ids_init, rope_pos=rpos_init) + # ) + # z_s_pol = self.pipe.extract_logits(logits_s_pol, N) + + # # --- 核心数学修正:将 Logits 转换为概率,防止梯度爆炸 --- + # p_T = F.softmax(z_T_target / float(cfg.tau_kd), dim=-1) + # p_A = F.softmax(z_A_cond / float(cfg.tau_kd), dim=-1) + + # # 目标方向:Teacher 概率 - Aux 概率 (遵循论文公式推导) + # bridge_target = (p_T - p_A).detach() + + # # 利用 MSE Trick 强制注入梯度 + # loss_bridge = 0.5 * F.mse_loss( + # z_s_pol.float(), + # (z_s_pol.float() + bridge_target).detach() + # ) + + # # 9. 单次反向传播 (合并 KD 和 Bridge) + # # 借用原来的 lambda_pg 参数来控制 bridge 损失的权重 + # loss_student = float(cfg.lambda_kd) * loss_kd + float(cfg.lambda_pg) * loss_bridge + # self.accelerator.backward(loss_student) + + # if float(cfg.grad_clip) > 0: + # torch.nn.utils.clip_grad_norm_( + # raw_twin.student.parameters(), float(cfg.grad_clip) + # ) + # self.optimizer.step() + + # # 恢复两者的可训练状态 + # raw_twin.student.requires_grad_(True) + # raw_twin.aux.requires_grad_(True) + + # # --- 兼容原始日志输出的占位符 --- + # H_mean = torch.tensor(0.0, device=device) + # loss_pg = loss_bridge.detach() # 将 bridge 损失映射给 pg 显示 + # logp = torch.tensor(0.0, device=device) + # self.baseline_ema = 0.0 + + # ---------------------------------------------------------------- + # Stage 5: Teacher forward — 破除视图死锁,生成目标后立刻释放 + # ---------------------------------------------------------------- + with torch.no_grad(): + if cfg.enable_teacher_cfg: + txt_dual = torch.cat([txt_ids, txt_uncond], dim=0) # [2B, L] + x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B,T,H,W] + ids_dual, rpos_dual, _ = self.pipe.build_inputs( + txt_dual, x_t_dual, self.latents_shape + ) + logits_T_dual = _get_logits( + self.pipe.teacher(ids_dual, rope_pos=rpos_dual) + ) + z_T_dual = self.pipe.extract_logits(logits_T_dual, N) # [2B,N,K] + + # 【显存救星 1】使用 .clone() 打断视图依赖,使得原始巨型张量可以被回收 + z_T_cond = z_T_dual[0:1].clone() # [1,N,K] + z_T_uncond = z_T_dual[1:2].clone() # [1,N,K] + ids_t, rpos_t = ids_dual[:B], rpos_dual[:B] + + # 立刻释放 17 GB 的双路缓冲 + del logits_T_dual, z_T_dual + torch.cuda.empty_cache() + else: + ids_t, rpos_t, _ = self.pipe.build_inputs(txt_ids, x_t, self.latents_shape) + logits_T = _get_logits(self.pipe.teacher(ids_t, rope_pos=rpos_t)) + z_T_cond = self.pipe.extract_logits(logits_T, N) + z_T_uncond = None + + # 计算 CFG guided target + z_T_guided = None + use_guided_ratio = 0.0 + if cfg.enable_teacher_cfg: + p_guided = _cfg_warmup_prob(step, float(cfg.teacher_cfg_prob), int(cfg.teacher_cfg_warmup_steps)) + use_guided = torch.rand(B, device=device) < p_guided + use_guided_ratio = float(use_guided.float().mean().item()) + + z_T_guided = _build_guided_logits( + z_T_cond, z_T_uncond, + t, float(cfg.teacher_cfg_scale), float(cfg.teacher_cfg_trunc), + ) + mask = use_guided.view(-1, 1, 1).expand_as(z_T_cond) + # 【显存救星 2】保持为 bf16 类型,避免膨胀到 8.5GB + z_T_target = torch.where(mask, z_T_guided, z_T_cond).to(dtype=z_T_cond.dtype).detach() + + # 立刻清理所有中间推导变量 + del z_T_cond, z_T_uncond, z_T_guided + torch.cuda.empty_cache() + else: + z_T_target = z_T_cond.detach() + + # ---------------------------------------------------------------- + # Stage 6: Aux update — 【显存救星 3】强行降维为单路前向传播 (Batch=1) + # ---------------------------------------------------------------- + raw_twin = self.accelerator.unwrap_model(self.model) + raw_twin.student.requires_grad_(False) + raw_twin.aux.requires_grad_(True) + + target_tokens = x_hat.detach() + + for _fr in range(int(cfg.fake_rounds)): + self.optimizer.zero_grad(set_to_none=True) + + # 只处理单路 ids_t,不处理 dual,砍掉 Aux 50% 显存! + logits_A = _get_logits( + self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) + ) + z_A_cond = self.pipe.extract_logits(logits_A, N) + + loss_aux_cond_v = F.cross_entropy( + z_A_cond.reshape(B * N, K), + target_tokens.reshape(B * N), + reduction="mean", + ) + + self.accelerator.backward(loss_aux_cond_v) + if float(cfg.grad_clip) > 0: + torch.nn.utils.clip_grad_norm_(raw_twin.aux.parameters(), float(cfg.grad_clip)) + self.optimizer.step() + + # 必须立刻释放 + del logits_A, z_A_cond + torch.cuda.empty_cache() + + loss_aux_cond_last = loss_aux_cond_v.detach() + + # ---------------------------------------------------------------- + # Stage 7 & 8: Student KD update & Aux Bridge + # ---------------------------------------------------------------- + raw_twin.student.requires_grad_(True) + raw_twin.aux.requires_grad_(False) + self.optimizer.zero_grad(set_to_none=True) + + # 7a. Student KD (强行降维为单路前向传播 Batch=1) + logits_S = _get_logits( + self.model("student", ids_t.detach(), rope_pos=rpos_t.detach()) + ) + z_S_cond = self.pipe.extract_logits(logits_S, N) + + # 使用 128 chunk size,确保极致安全 + loss_kd = _stable_kl( + z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=128 + ).mean() + + del logits_S, z_S_cond + torch.cuda.empty_cache() + + # 7b. 获取 Aux 的预测作为桥梁 (依然单路) + with torch.no_grad(): + logits_A = _get_logits( + self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) + ) + z_A_cond = self.pipe.extract_logits(logits_A, N) + + # 8. Student 对 x_init 进行前向传播 + ids_init, rpos_init, _ = self.pipe.build_inputs(txt_ids, x_init, self.latents_shape) + logits_s_pol = _get_logits( + self.model("student", ids_init, rope_pos=rpos_init) + ) + z_s_pol = self.pipe.extract_logits(logits_s_pol, N) + + # 【显存救星 4】在 bf16 精度下计算 Softmax 概率,防止 float32 炸存 + p_T = F.softmax(z_T_target / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype) + p_A = F.softmax(z_A_cond / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype) + + bridge_target = (p_T - p_A).detach() + + # 拿到 bridge_target 后,前面所有百兆甚至 G 级的张量统统干掉 + del p_T, p_A, logits_A, z_A_cond, z_T_target + torch.cuda.empty_cache() + + # 伪梯度注入 + loss_bridge = 0.5* K * F.mse_loss( + z_s_pol.float(), + (z_s_pol.float() + bridge_target.float()).detach() + ) + + # 9. 统一反向传播 + loss_student = float(cfg.lambda_kd) * loss_kd + float(cfg.lambda_pg) * loss_bridge + self.accelerator.backward(loss_student) + + if float(cfg.grad_clip) > 0: + torch.nn.utils.clip_grad_norm_(raw_twin.student.parameters(), float(cfg.grad_clip)) + self.optimizer.step() + + # 恢复状态 + raw_twin.student.requires_grad_(True) + raw_twin.aux.requires_grad_(True) + + # 最后的清理 + del logits_s_pol, z_s_pol, bridge_target + torch.cuda.empty_cache() + + H_mean = torch.tensor(0.0, device=device) + loss_pg = loss_bridge.detach() + logp = torch.tensor(0.0, device=device) + self.baseline_ema = 0.0 + + # Advance LR schedulers. + self.scheduler_s.step() + self.scheduler_a.step() + + # ---------------------------------------------------------------- + # Step 1 sanity assertions (lightweight; runs only at step 1) + # ---------------------------------------------------------------- + # if step == 1: + # self._step1_assertions( + # x_init, ids_init, rpos_init, z_s, p_s, logp, + # z_T_cond, z_S_cond, x_t, B, T, H, W, + # ) + + # ---------------------------------------------------------------- + # Token-level collapse detection + # ---------------------------------------------------------------- + tok_entropy = self._token_entropy(x_hat) + if not hasattr(self, "_init_tok_entropy"): + self._init_tok_entropy = tok_entropy + collapse_frac = float(cfg.get("collapse_warn_frac", 0.2)) + if tok_entropy < collapse_frac * self._init_tok_entropy: + self.logger.warning( + f"[COLLAPSE] step={step} tok_H={tok_entropy:.3f} " + f"init={self._init_tok_entropy:.3f} " + f"ratio={tok_entropy / max(self._init_tok_entropy, 1e-8):.2f} " + f"< {collapse_frac}. Try increasing lambda_ent." + ) + + stats["time"] = timer.toc() + stats["metrics"] = collections.OrderedDict( + sorted( + { + "loss_aux_cond": float(loss_aux_cond_last.item()), + "loss_kd_cond": float(loss_kd.item()), + "loss_pg": float(loss_pg.item()), + "H_mean": float(H_mean.item()), + "tok_entropy": float(tok_entropy), + "mean_logp_tok": float(logp.mean().item()), + "baseline_ema": float(self.baseline_ema), + "use_guided_ratio": float(use_guided_ratio), + }.items() + ) + ) + return stats + + # ----------------------------------------------------------------------- + # Train loop + # ----------------------------------------------------------------------- + + def train_loop(self): + """Main training loop (mirrors diffnext.engine.train_engine.Trainer).""" + cfg_exp = self.config.experiment + max_steps = int(self.config.training.max_train_steps) + log_every = int(cfg_exp.log_every) + save_every = int(cfg_exp.save_every) + + self.global_step = int(self.config.experiment.get("resume_iter", 0)) + # Sync LR schedulers to resume step (set _step_count directly; + # CosineLR uses _step_count internally in get_decay()). + self.scheduler_s._step_count = self.global_step + self.scheduler_a._step_count = self.global_step + + # [可选补充] 如果是续传,让 accelerator 自动恢复被切分的 Optimizer 等状态 + if self.global_step > 0: + ckpt_dir = os.path.join(self.config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}") + if os.path.exists(ckpt_dir): + self.accelerator.load_state(ckpt_dir) + self.logger.info(f"✅ ZeRO-3 完整状态 (包含 Optimizer) 已从 {ckpt_dir} 恢复") + + timer = profiler.Timer() + self.logger.info( + f"[train] Starting from step {self.global_step} / {max_steps}" + ) + + while self.global_step < max_steps: + self.global_step += 1 + with timer.tic_and_toc(): + stats = self.run_step(self.global_step) + self._add_metrics(stats) + + if self.global_step % log_every == 0: + self._log_metrics(stats) + + if self.global_step % (10 * log_every) == 0: + self.logger.info( + profiler.get_progress(timer, self.global_step, max_steps) + ) + + if self.global_step % save_every == 0: + self.save(self.global_step) + + # Final log + save (only when loop ran at least one step). + if self.global_step > int(self.config.experiment.get("resume_iter", 0)): + self._log_metrics({**stats, "step": self.global_step}) # noqa: F821 + self.accelerator.wait_for_everyone() + self.save(self.global_step, suffix="final") + self.accelerator.end_training() + + # ----------------------------------------------------------------------- + # Checkpoint helpers + # ----------------------------------------------------------------------- + + # def save(self, step: int, suffix: str = None) -> None: + # """Save student + aux state_dicts (rank0 only). + + # Saved as: + # /checkpoints/checkpoint-/student.pt + # /checkpoints/checkpoint-/aux.pt + + # The student.pt can be used for inference by replacing the + # transformer weights in a URSAPipeline (see README). + # """ + # if not self.accelerator.is_main_process: + # return + + # folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}" + # ckpt_dir = os.path.join( + # self.config.experiment.output_dir, "checkpoints", folder + # ) + # os.makedirs(ckpt_dir, exist_ok=True) + + # raw_student = self.accelerator.unwrap_model(self.model).student + # raw_aux = self.accelerator.unwrap_model(self.model).aux + + # student_path = os.path.join(ckpt_dir, "student.pt") + # aux_path = os.path.join(ckpt_dir, "aux.pt") + + # torch.save(raw_student.state_dict(), student_path) + # torch.save(raw_aux.state_dict(), aux_path) + + # # Also save training state for resuming. + # state = { + # "global_step": step, + # "baseline_ema": self.baseline_ema, + # "optimizer": self.optimizer.state_dict(), + # } + # torch.save(state, os.path.join(ckpt_dir, "train_state.pt")) + # self.logger.info(f"[save] step={step} → {ckpt_dir}") + + def save(self, step: int, suffix: str = None) -> None: + """Save student + aux state_dicts (支持 DeepSpeed ZeRO-3 自动聚合).""" + + # ⚠️ 【极其重要】:get_state_dict 必须由所有 8 张卡共同执行! + # 绝对不能把它放在 is_main_process 判断的里面,否则会触发跨卡死锁! + full_state_dict = self.accelerator.get_state_dict(self.model) + + # 只有主进程(0号卡)负责把聚合好的完整参数写进硬盘 + if not self.accelerator.is_main_process: + return + + folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}" + ckpt_dir = os.path.join( + self.config.experiment.output_dir, "checkpoints", folder + ) + os.makedirs(ckpt_dir, exist_ok=True) + + # 从 TwinModel 的完整字典中,根据前缀拆分出 student 和 aux 的独立权重 + student_state = {k.replace("student.", ""): v for k, v in full_state_dict.items() if k.startswith("student.")} + aux_state = {k.replace("aux.", ""): v for k, v in full_state_dict.items() if k.startswith("aux.")} + + student_path = os.path.join(ckpt_dir, "student.pt") + aux_path = os.path.join(ckpt_dir, "aux.pt") + + torch.save(student_state, student_path) + torch.save(aux_state, aux_path) + + # 保存辅助状态 + state = { + "global_step": step, + "baseline_ema": self.baseline_ema, + } + torch.save(state, os.path.join(ckpt_dir, "train_state.pt")) + self.logger.info(f"[save] step={step} → {ckpt_dir} (ZeRO-3 Gathered)") + + # ----------------------------------------------------------------------- + # Logging helpers (same API as original Trainer) + # ----------------------------------------------------------------------- + + def _add_metrics(self, stats: dict) -> None: + for k, v in stats["metrics"].items(): + if k not in self.metrics: + self.metrics[k] = profiler.SmoothedValue() + self.metrics[k].update(v) + + def _log_metrics(self, stats: dict) -> None: + iter_template = "Iteration %d, lr_s=%.2e lr_a=%.2e, time=%.2fs" + self.logger.info( + iter_template + % ( + stats["step"], + stats.get("lr_student", 0.0), + stats.get("lr_aux", 0.0), + stats.get("time", 0.0), + ) + ) + metric_template = " Train %s: %s" + for k, v in self.metrics.items(): + self.logger.info(metric_template % (k, v)) + tracker_logs = {k: v.median for k, v in self.metrics.items()} + tracker_logs.update( + { + "lr_student": stats.get("lr_student", 0.0), + "time": stats.get("time", 0.0), + } + ) + self.accelerator.log(tracker_logs, step=stats["step"]) + self.metrics.clear() + + # ----------------------------------------------------------------------- + # Sanity checks (step 1 only) + # ----------------------------------------------------------------------- + + def _step1_assertions( + self, x_init, ids_init, rpos_init, z_s, p_s, logp, + z_T_cond, z_S_cond, x_t, B, T, H, W, + ) -> None: + """Shape / value-domain assertions (mirrors single-card script).""" + N, K = self.N, self.K + lm_vocab = self.pipe.teacher.config.lm_vocab_size + L_plus_N1 = ids_init.size(1) + txt_len = L_plus_N1 - (N + 1) + + assert x_init.dtype == torch.long + assert x_init.min() >= 0 and x_init.max() < K + + assert ids_init.shape == (B, L_plus_N1), ids_init.shape + txt_part = ids_init[:, :txt_len] + vis_part = ids_init[:, -N:] + assert (txt_part < lm_vocab).all(), "text tokens in visual range" + assert (vis_part >= lm_vocab).all(), "visual tokens not shifted" + assert (vis_part < lm_vocab + K).all(), "visual tokens exceed lm_vocab+K" + + assert rpos_init.shape == (B, L_plus_N1, 3), rpos_init.shape + assert z_s.shape == (B, N, K), z_s.shape + p_err = float((p_s.sum(-1) - 1).abs().max().item()) + assert p_err < 1e-3, f"p_s not normalised: max_dev={p_err:.2e}" + + assert not torch.isnan(logp).any(), "logp has NaN" + assert not torch.isinf(logp).any(), "logp has Inf" + assert x_t.min() >= 0 and x_t.max() < K + + assert z_T_cond.shape == z_S_cond.shape == (B, N, K), ( + f"z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}" + ) + + # Teacher has no grad. + teacher_grads = [ + p for p in self.pipe.teacher.parameters() if p.grad is not None + ] + assert len(teacher_grads) == 0, "teacher has grads — not frozen" + + # Student has grad (from PG backward). + raw_s = self.accelerator.unwrap_model(self.model).student + student_grad_norms = [ + float(p.grad.norm().item()) + for p in raw_s.parameters() + if p.grad is not None + ] + assert len(student_grad_norms) > 0, "student has NO grads — grad flow broken" + + # ########################## + # raw_t = self.pipe.teacher + # raw_s = self.accelerator.unwrap_model(self.model).student + + # # (a) 共享存储检查:零开销 + # pt0 = next(raw_t.parameters()) + # ps0 = next(raw_s.parameters()) + # self.logger.info(f"[assert] shared_storage={pt0.data_ptr() == ps0.data_ptr()}") + + # # (b) 参数差异:只采样前 4096 个元素,避免巨型临时张量 + # with torch.no_grad(): + # a = pt0.view(-1)[:4096].float() + # b = ps0.view(-1)[:4096].float() + # self.logger.info(f"[assert] param_delta_sample_max={float((a-b).abs().max().item()):.3e}") + + # # (c) logits 差异:只采样小子块(64 token × 256 vocab) + # with torch.no_grad(): + # idx_n = torch.randint(0, self.N, (64,), device=z_T_cond.device) + # idx_k = torch.randint(0, self.K, (256,), device=z_T_cond.device) + # subT = z_T_cond[0, idx_n][:, idx_k].float() + # subS = z_S_cond[0, idx_n][:, idx_k].float() + # self.logger.info(f"[assert] logits_delta_sub_max={float((subT-subS).abs().max().item()):.3e}") + # ########################### + + self.logger.info("[assert] Step-1 shape/grad assertions PASSED ✓") + self.logger.info( + f"[assert] z_T_cond shape={z_T_cond.shape} " + f"min={float(z_T_cond.min().item()):.3f} " + f"max={float(z_T_cond.max().item()):.3f}" + ) + self.logger.info( + f"[assert] z_S_cond shape={z_S_cond.shape} " + f"min={float(z_S_cond.min().item()):.3f} " + f"max={float(z_S_cond.max().item()):.3f}" + ) + + @staticmethod + def _token_entropy(x_hat: torch.Tensor) -> float: + """Histogram entropy of sampled token indices (collapse detection).""" + counts = x_hat.flatten().bincount(minlength=1).float() + p = counts / counts.sum() + p = p[p > 0] + return float(-(p * p.log()).sum().item()) + + +def main(): + """Entry point — identical pattern to scripts/train.py.""" + config = omegaconf_utils.get_config() + os.makedirs(config.experiment.output_dir, exist_ok=True) + + accelerator = accelerate_utils.build_accelerator(config) + accelerate_utils.build_wandb(config, accelerator=accelerator) + logger = accelerate_utils.set_logger( + config.experiment.output_dir, accelerator=accelerator + ) + + device_seed = int(config.training.seed) + accelerator.process_index + engine_utils.manual_seed(device_seed, (accelerator.device.index, device_seed)) + + if accelerator.is_main_process: + config_path = os.path.join(config.experiment.output_dir, "config.yaml") + omegaconf_utils.save_config(config, config_path) + + logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}") + + trainer = DistillTrainer(config, accelerator, logger) + trainer.train_loop() + + +if __name__ == "__main__": + main() diff --git a/URSA/scripts/train_onestep.md b/URSA/scripts/train_onestep.md new file mode 100644 index 0000000000000000000000000000000000000000..1ed5a35651f24426b8ea9da2d826a40d5222b96b --- /dev/null +++ b/URSA/scripts/train_onestep.md @@ -0,0 +1,78 @@ +代码解读(关键设计决策) +utils_ursa_inputs.py +build_ursa_inputs(transformer, txt_ids, visual_tokens, latents_shape, device) +严格复刻 URSAPipeline.__call__ 的 token 拼接逻辑: +img_ids = pad(latents_flat + lm_vocab_size, (1,0), value=bov_token_id)input_ids = cat([txt_ids, img_ids], dim=1)blk_pos = flex_rope.get_pos(latents_shape, L)rope_pos = cat([txt_pos, blk_pos[0]]).unsqueeze(0).expand(B,-1,-1) +extract_visual_logits(logits, N, K) +坑 1 防护:z = logits[:, -(N+1):-1](causal slice),然后根据最后一维是否等于 K 决定是否再切 slice。 +sample_t_curriculum — 前 10k 步用 t = 1-(1-u)^2 偏大,之后恢复均匀采样。 +train_onestep_ursa_dimo.py 训练循环 +每一步的 9 个 stage 对应 DiMO 论文的完整流程: +Stage 操作 梯度 +1-2 tokenize + 采样 x_init (80% uniform / 20% corrupt) 无 +3 student 在 x_init 上 1-step forward → x_hat, logp, H ✅ student +4 add_noise(x_hat, t) → x_t 无(离散采样截断) +5 teacher 在 x_t → p_T 无 (no_grad) +6 aux 在 x_t → Jeffrey(p_T, p_A) → backward → aux update ✅ aux only +7 student 在 x_t → KL(p_T ‖ p_S_t) ✅ student +8 REINFORCE: r=-loss_aux, adv=r-EMA, loss_pg=-(adv·logp) ✅ student (via logp) +9 L_s = λ_pg·loss_pg + λ_kd·loss_kd - λ_ent·H → student update ✅ student +运行命令示例 +端到端冒烟测试(单卡,17帧256×256,2000步): +python scripts/train_onestep_ursa_dimo.py \ + --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \ + --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \ + --num_frames 17 --height 256 --width 256 \ + --batch_size 1 --num_steps 2000 \ + --log_every 50 --save_every 500 \ + --out_dir ./outputs/dimo_test + +评估(1-step student vs 25-step teacher): +python scripts/eval_onestep_ursa.py \ + --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \ + --student_ckpt ./outputs/dimo_test/final/student.pt \ + --num_frames 17 --height 256 --width 256 \ + --teacher_steps 25 \ + --out_dir ./outputs/eval + +扩展到完整分辨率(49帧 320×512): +python scripts/train_onestep_ursa_dimo.py \ + --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B/ \ + --prompt_file /gfs/space/private/fengzl/World_Model/Koala-36M-v1/ \ + --num_frames 49 --height 320 --width 512 \ + --batch_size 2 --num_steps 50000 \ + --lambda_ent 0.01 --t_curriculum_steps 10000 \ + --mixed_precision bf16 --out_dir ./outputs/dimo_full + +三大稳定性机制(缺一不可) +t curriculum — 前 10k 步 t 偏大,teacher 分布更尖锐,KD 信号更强,避免早期 student 随机游走 +p_init mixing — 20% batch 用 corrupt(x_hat_prev, r=0.2),让 student 学会"一步修复" +熵正则 λ_ent — 初始 0.01,若检测到 tok_entropy 下降就升到 0.05 + + +8 卡启动命令 +accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml --machine_rank 0 --num_machines 1 --num_processes 8 scripts/train_distill_dimo.py config=./configs/distill_dimo.yaml experiment.output_dir=./experiments/distill_dimo distill.teacher_ckpt=/gfs/space/private/fengzl/World_Model/URSA-1.7B distill.prompt_source=/gfs/space/private/fengzl/World_Model/Koala-36M-v1 distill.batch_size_per_gpu=1 + +Smoke Test(50 步,保存 checkpoint) +accelerate launch --num_processes 8 --mixed_precision bf16 \ + scripts/train_distill_dimo.py \ + config="./configs/distill_dimo.yaml" \ + experiment.output_dir="./experiments/smoke" \ + distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \ + distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1" \ + training.max_train_steps=50 \ + experiment.save_every=50 + + +加载 student.pt 做 1-step 推理 +from diffnext.pipelines import URSAPipelineimport torchpipe = URSAPipeline.from_pretrained( "/path/to/URSA-1.7B-IBQ1024", torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")# 替换 transformer 权重为 studentstate = torch.load("experiments/distill_dimo/checkpoints/final/student.pt", map_location="cuda")pipe.transformer.load_state_dict(state, strict=True)# 1-step 生成(num_inference_steps=1)frames = pipe( prompt="a dog running on a beach", height=256, width=256, num_frames=17, num_inference_steps=1, guidance_scale=3.0,).frames + + +最新 修改分辨率和cfg后 +accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \ + --machine_rank 0 --num_machines 1 --num_processes 8 \ + scripts/train_distill_dimo.py \ + config="./configs/distill_dimo.yaml" \ + experiment.output_dir="./experiments/distill_dimo" \ + distill.teacher_ckpt="/gfs/space/private/fengzl/World_Model/URSA-1.7B" \ + distill.prompt_source="/gfs/space/private/fengzl/World_Model/Koala-36M-v1" \ No newline at end of file diff --git a/URSA/scripts/train_onestep_ursa_dimo.py b/URSA/scripts/train_onestep_ursa_dimo.py new file mode 100644 index 0000000000000000000000000000000000000000..18a07d992f27f04350cc3c4b3368f44ab529e633 --- /dev/null +++ b/URSA/scripts/train_onestep_ursa_dimo.py @@ -0,0 +1,1303 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024-present, BAAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------- +"""URSA → URSA one-step distillation via Di[M]O-style on-policy training. + +Verified native inference regime (from A/B testing — ground truth): + height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50. + no_cfg (guidance_scale=1) does NOT produce valid output for this URSA checkpoint. + All defaults below align to this verified regime. + +Algorithm (9 stages per iteration) +------------------------------------ + teacher : frozen URSA — provides supervision at pseudo-intermediate x_t. + student : trainable copy — 1-step target. + aux : trainable copy — approximates teacher at x_t; reduces REINFORCE variance. + + Stage 1 : tokenise prompts (cond + uncond when CFG enabled) → txt_ids [B,L] + Stage 2 : sample x_init [B,T,H,W] ~ Uniform(K) (+ optional p_init mixing) + Stage 3 : student 1-step forward on x_init (cond only) → x_hat, logp, H + Stage 4 : pseudo-intermediate x_t = scheduler.add_noise(x_hat, t) + Stage 5 : teacher forward on x_t (CFG=7 dual-branch is the default) + Stage 6 : aux forward → Jeffrey KD + Stage 7 : student forward on x_t → KL KD + Stage 8 : reward = -KL(z_T_cond, z_S_cond) [detached] + Stage 9 : two-backward student update + +Usage: + # Smoke test (verified native regime): + python scripts/train_onestep_ursa_dimo.py \\ + --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \\ + --enable_teacher_cfg --teacher_cfg_scale 7.0 \\ + --num_frames 49 --height 320 --width 512 --dry_run + + # Full training: + python scripts/train_onestep_ursa_dimo.py \\ + --teacher_ckpt /path/to/URSA --prompt_file prompts.txt \\ + --enable_teacher_cfg --teacher_cfg_scale 7.0 \\ + --num_frames 49 --height 320 --width 512 \\ + --batch_size 1 --num_steps 10000 --out_dir ./outputs/dimo_cfg +""" + +import argparse +import copy +import json +import math +import os +import sys + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from diffnext.pipelines import URSAPipeline +from src.distill.prompt_dataset import InfiniteDataLoader, PromptDataset, make_collate_fn, CSVSpec +from src.distill.utils_ursa_inputs import ( + build_ursa_inputs, + compute_latents_shape, + corrupt_tokens, + extract_visual_logits, + sample_t_curriculum, +) + +def _get_logits(out): + if isinstance(out, (tuple, list)): + return out[0] + if hasattr(out, "sample"): + return out.sample + if hasattr(out, "logits"): + return out.logits + return out + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="URSA DiMO one-step distillation") + + # Model / data + p.add_argument("--teacher_ckpt", required=True) + p.add_argument("--prompt_file", required=True) + p.add_argument("--out_dir", default="./outputs/dimo") + + # Video geometry (verified native: 320×512×49) + p.add_argument("--num_frames", type=int, default=49) + p.add_argument("--height", type=int, default=320) + p.add_argument("--width", type=int, default=512) + p.add_argument("--max_prompt_length", type=int, default=320) + + # Training + p.add_argument("--batch_size", type=int, default=1) + p.add_argument("--num_steps", type=int, default=10_000) + p.add_argument("--lr_student", type=float, default=1e-5) + p.add_argument("--lr_aux", type=float, default=1e-5) + p.add_argument("--weight_decay", type=float, default=0.01) + p.add_argument("--grad_clip", type=float, default=1.0) + p.add_argument("--mixed_precision", default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--log_every", type=int, default=50) + p.add_argument("--save_every", type=int, default=1000) + + # Loss weights + p.add_argument("--lambda_pg", type=float, default=1.0) + p.add_argument("--lambda_kd", type=float, default=0.5) + p.add_argument("--lambda_ent", type=float, default=0.01) + p.add_argument("--tau", type=float, default=1.0, help="Student sampling temperature") + p.add_argument("--tau_kd", type=float, default=1.0, help="KD softmax temperature") + + # ---- Teacher CFG (DiMO true_cfg style) ---------------------------- + p.add_argument("--enable_teacher_cfg", action="store_true", default=False, + help="Enable teacher-side CFG for KD target. " + "False → prior single-branch behavior (fallback).") + p.add_argument("--teacher_cfg_scale", type=float, default=7.0, + help="CFG scale s (verified working value=7)") + p.add_argument("--teacher_cfg_prob", type=float, default=1.0, + help="Max prob of using guided target per sample (after warmup)") + p.add_argument("--teacher_cfg_warmup_steps", type=int, default=2000, + help="Steps to ramp teacher_cfg_prob 0 → teacher_cfg_prob") + p.add_argument("--teacher_cfg_trunc", type=float, default=0.9, + help="t threshold: when t >= trunc, s=1. Set >=1.0 to disable.") + p.add_argument("--lambda_kd_uncond", type=float, default=0.3, + help="Weight for uncond-branch KD / aux loss") + p.add_argument("--reward_use_guided", action="store_true", default=False, + help="[RISKY] Use guided teacher logits for REINFORCE reward.") + # ---- Eval CFG (inference-time) ----------------------------------- + p.add_argument("--eval_cfg_scale", type=float, default=7.0) + p.add_argument("--use_cfg_eval", action="store_true", default=True) + + # DiMO extensions + p.add_argument("--use_surrogate_grad", action="store_true", + help="DiMO surrogate MSE trick applied to Stage-3 logits") + p.add_argument("--lambda_surr", type=float, default=1.0) + p.add_argument("--fake_rounds", type=int, default=1, + help="Aux updates per generator update (DiMO=2)") + + # Stability + p.add_argument("--t_curriculum_steps", type=int, default=10_000) + p.add_argument("--p_mix_corrupt_frac", type=float, default=0.2) + p.add_argument("--p_init_mix_ratio", type=float, default=0.2) + p.add_argument("--collapse_warn_frac", type=float, default=0.2) + + # Debug + p.add_argument("--dry_run", action="store_true", + help="Run 1 step + grad-flow check, then exit") + p.add_argument("--debug_dump", type=int, default=0, + help="Dump token histogram + x_hat every N steps (0=off)") + + p.add_argument("--device", type=int, default=0) + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Checkpoint +# --------------------------------------------------------------------------- + +def save_checkpoint(model, path: str, name: str = "student"): + os.makedirs(path, exist_ok=True) + ckpt_path = os.path.join(path, f"{name}.pt") + torch.save(model.state_dict(), ckpt_path) + print(f"[save] {ckpt_path}") + + +# --------------------------------------------------------------------------- +# Stable KL / Jeffrey divergence helpers (float32 + log_softmax) +# --------------------------------------------------------------------------- + +def _stable_kl(z_p: torch.Tensor, z_q: torch.Tensor, tau: float = 1.0) -> torch.Tensor: + """KL(p||q) from raw logits, float32 + log_softmax. → [B] (mean over N tokens). + + p = softmax(z_p/tau), q = softmax(z_q/tau) + KL(p||q) = sum_k p_k * (log p_k - log q_k) + + Both log_p and log_q are computed via log_softmax to avoid + log(softmax(...)) numerical issues. + """ + lp = F.log_softmax(z_p.float() / tau, dim=-1) # [B, N, K] + lq = F.log_softmax(z_q.float() / tau, dim=-1) # [B, N, K] + return (lp.exp() * (lp - lq)).sum(-1).mean(-1) # [B] + + +def _stable_jeffrey(z_p: torch.Tensor, z_q: torch.Tensor, tau: float = 1.0) -> torch.Tensor: + """Symmetric KL (Jeffrey) from logits, float32 + log_softmax. → [B].""" + return _stable_kl(z_p, z_q, tau) + _stable_kl(z_q, z_p, tau) + + +# --------------------------------------------------------------------------- +# Batch-concat input builder (ONE forward for cond + uncond) +# --------------------------------------------------------------------------- + +def _build_dual_inputs(teacher_ref, txt_cond, txt_uncond, x_t, latents_shape, device): + """Concatenate cond+uncond into a single [2B] forward-pass input. + + Returns (ids_dual [2B, L+N+1], rpos_dual [2B, L+N+1, 3], N). + After the forward: chunk(2, dim=0) → (z_cond [B], z_uncond [B]). + + All three models (teacher/aux/student) share the SAME ids_dual / rpos_dual + so the tokens are constructed only once per step. + """ + txt_dual = torch.cat([txt_cond, txt_uncond], dim=0) # [2B, L] + x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B, T, H, W] + return build_ursa_inputs(teacher_ref, txt_dual, x_t_dual, latents_shape, device) + + +# --------------------------------------------------------------------------- +# flex_attn probe / reset helpers +# --------------------------------------------------------------------------- + +def _probe_flex_attn(model, label: str = "") -> object: + """Return the FlexAttentionCausal2D object if present, else None.""" + return getattr(model, "flex_attn", None) + + +def _print_flex_attn_state(model, label: str): + fa = _probe_flex_attn(model, label) + if fa is None: + print(f" [flex_attn/{label}] not present on model") + return + print( + f" [flex_attn/{label}] offsets={fa.offsets!r} " + f"block_mask={'set' if fa.block_mask is not None else 'None'} " + f"cu_offsets={'set' if fa.cu_offsets is not None else 'None'}" + ) + + +def _reset_flex_attn(model, label: str = "", verbose: bool = False): + """Reset flex_attn to None offsets so standard causal attention is used. + + Our distillation training processes each sample independently (batch dim) + so block-packed attention (offsets != None) is not needed and must be cleared + to avoid cross-sample mask contamination. + """ + fa = _probe_flex_attn(model, label) + if fa is None: + return + old_offsets = fa.offsets + fa.offsets = None + fa.block_mask = None + fa.cu_offsets = None + if verbose: + print(f" [flex_attn/{label}] reset: was={old_offsets!r} → None (standard causal)") + + +# --------------------------------------------------------------------------- +# Teacher CFG target construction +# --------------------------------------------------------------------------- + +def _compute_cfg_scale(t: torch.Tensor, cfg_scale: float, trunc: float) -> torch.Tensor: + """Per-sample CFG scale [B]: s=cfg_scale when t < trunc, else s=1.""" + s = torch.full_like(t, cfg_scale) + if trunc < 1.0: + s = torch.where(t >= trunc, torch.ones_like(t), s) + return s + + +def _cfg_warmup_prob(step: int, cfg_prob: float, warmup_steps: int) -> float: + """Linear warmup: 0 → cfg_prob over warmup_steps steps.""" + if warmup_steps <= 0: + return cfg_prob + return cfg_prob * min(1.0, step / warmup_steps) + + +def _build_guided_logits( + z_T_cond: torch.Tensor, # [B, N, K] float32 + z_T_uncond: torch.Tensor, # [B, N, K] float32 + t: torch.Tensor, # [B] ∈ (0,1) + cfg_scale: float, + trunc: float, +) -> torch.Tensor: + """z_guided = z_uncond + s*(z_cond - z_uncond), per-sample s [B,1,1].""" + s = _compute_cfg_scale(t, cfg_scale, trunc).view(-1, 1, 1) # [B,1,1] + return z_T_uncond + s * (z_T_cond - z_T_uncond) # [B, N, K] + + +def _select_target( + z_guided: torch.Tensor, # [B, N, K] + z_cond: torch.Tensor, # [B, N, K] + use_guided: torch.Tensor, # [B] bool — per-sample selection +) -> torch.Tensor: + """Per-sample: z_guided where use_guided[b]=True, else z_cond.""" + mask = use_guided.view(-1, 1, 1).expand_as(z_cond) + return torch.where(mask, z_guided, z_cond) + + +# --------------------------------------------------------------------------- +# Gradient-flow debug +# --------------------------------------------------------------------------- + +def debug_grad_flow( + teacher, student, aux, + txt_cond, txt_uncond, x_t, latents_shape, device, K, N, tau, tau_kd, + enable_teacher_cfg, +): + """One fwd+bwd without optimizer.step(). + + Asserts: + - teacher: zero grads (frozen) + - aux: non-zero grads after loss_aux.backward() + - student: non-zero grads after loss_student.backward() + + All cond/uncond forwards are batch-concatenated per requirement (1). + """ + print("\n" + "=" * 64) + print("[grad_flow] Starting gradient flow debug …") + B = txt_cond.size(0) + + # -- Stage 3: student on x_init (cond only) ---------------------- + x_init_dbg = torch.randint(0, K, x_t.shape, device=device, dtype=torch.long) + ids_init, rpos_init, _ = build_ursa_inputs(teacher, txt_cond, x_init_dbg, latents_shape, device) + logits_s = student(ids_init, rope_pos=rpos_init).sample + z_s = extract_visual_logits(logits_s.float(), N, K) + p_s = F.softmax(z_s / tau, dim=-1) + x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) + logp = p_s.clamp(1e-8).log().gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) + H_mean = -(p_s * p_s.clamp(1e-8).log()).sum(-1).mean() + + # -- Stage 5: teacher forward — [2B] if CFG, else [B] ------------ + if enable_teacher_cfg and txt_uncond is not None: + ids_dual, rpos_dual, _ = _build_dual_inputs(teacher, txt_cond, txt_uncond, x_t, latents_shape, device) + with torch.no_grad(): + logits_T_dual = teacher(ids_dual, rope_pos=rpos_dual).sample.float() + z_T_dual = extract_visual_logits(logits_T_dual, N, K) + z_T_cond_dbg, z_T_uncond_dbg = z_T_dual.chunk(2, dim=0) + t_dbg = torch.full((B,), 0.5, device=device, dtype=torch.float32) + z_T_guided_dbg = _build_guided_logits( + z_T_cond_dbg.float(), z_T_uncond_dbg.float(), t_dbg, 3.0, 0.9) + z_T_target_dbg = z_T_guided_dbg.detach() + print(f" [grad_flow] z_T_cond shape={z_T_cond_dbg.shape} " + f"min={z_T_cond_dbg.min():.3f} max={z_T_cond_dbg.max():.3f}") + print(f" [grad_flow] z_T_uncond shape={z_T_uncond_dbg.shape} " + f"min={z_T_uncond_dbg.min():.3f} max={z_T_uncond_dbg.max():.3f}") + print(f" [grad_flow] z_T_guided shape={z_T_guided_dbg.shape} " + f"min={z_T_guided_dbg.min():.3f} max={z_T_guided_dbg.max():.3f}") + ids_t_ref = ids_dual[:B] + rpos_t_ref = rpos_dual[:B] + ids_fwd = ids_dual + rpos_fwd = rpos_dual + else: + ids_t_ref, rpos_t_ref, _ = build_ursa_inputs(teacher, txt_cond, x_t, latents_shape, device) + with torch.no_grad(): + logits_T = teacher(ids_t_ref, rope_pos=rpos_t_ref).sample.float() + z_T_target_dbg = extract_visual_logits(logits_T, N, K).detach() + ids_fwd = ids_t_ref + rpos_fwd = rpos_t_ref + + # Dual-path shape check (teacher vs student, same input) + with torch.no_grad(): + z_T_ref2 = extract_visual_logits( + teacher(ids_t_ref, rope_pos=rpos_t_ref).sample.float(), N, K) + z_S_ref2 = extract_visual_logits( + student(ids_t_ref.detach(), rope_pos=rpos_t_ref.detach()).sample.float(), N, K) + if z_T_ref2.shape != z_S_ref2.shape: + raise RuntimeError( + f"[FATAL] Dual-path shape mismatch: z_T={z_T_ref2.shape} z_S={z_S_ref2.shape}" + ) + print(f" [grad_flow] Dual-path check OK: shape={z_T_ref2.shape}") + + # -- Aux backward — [2B] if CFG, else [B] ------------------------- + logits_A = aux(ids_fwd.detach(), rope_pos=rpos_fwd.detach()).sample + if enable_teacher_cfg and txt_uncond is not None: + z_A_dual2 = extract_visual_logits(logits_A.float(), N, K) + z_A_cond_dbg, _ = z_A_dual2.chunk(2, dim=0) + else: + z_A_cond_dbg = extract_visual_logits(logits_A.float(), N, K) + loss_aux_sample = _stable_jeffrey(z_T_target_dbg, z_A_cond_dbg, tau_kd) + loss_aux = loss_aux_sample.mean() + loss_aux.backward() + + teacher_grads = [p.grad for p in teacher.parameters() if p.grad is not None] + aux_grads = [p.grad.norm().item() for p in aux.parameters() if p.grad is not None] + print(f" [grad_flow] teacher grads with non-None grad: {len(teacher_grads)} (must be 0)") + if aux_grads: + print(f" [grad_flow] aux grad norm min={min(aux_grads):.3e} " + f"mean={sum(aux_grads)/len(aux_grads):.3e} max={max(aux_grads):.3e}") + else: + print(" [grad_flow] ⚠️ aux has NO grads") + for param in aux.parameters(): + param.grad = None + + # -- Student backward — [B] (cond only for simplicity) ------------ + logits_S = student(ids_t_ref.detach(), rope_pos=rpos_t_ref.detach()).sample + z_S_cond = extract_visual_logits(logits_S.float(), N, K) + loss_kd = _stable_kl(z_T_target_dbg, z_S_cond, tau_kd).mean() + adv = (loss_aux_sample.detach() * 0 + 1.0) # dummy advantage (shape check) + assert not adv.requires_grad, "[BUG] adv must be detached" + loss_student = -(adv * logp).mean() + loss_kd - 0.01 * H_mean + loss_student.backward() + + student_grads = [p.grad.norm().item() for p in student.parameters() if p.grad is not None] + if student_grads: + print(f" [grad_flow] student grad norm min={min(student_grads):.3e} " + f"mean={sum(student_grads)/len(student_grads):.3e} " + f"max={max(student_grads):.3e}") + else: + print(" [grad_flow] ⚠️ student has NO grads — diagnosing:") + print(f" logp.requires_grad={logp.requires_grad}") + print(f" z_s.requires_grad={z_s.requires_grad}") + + assert len(teacher_grads) == 0, "teacher has grads — not frozen" + assert len(aux_grads) > 0, "aux has no grads after loss_aux.backward()" + assert len(student_grads) > 0, "student has no grads — grad flow broken" + + for m in (student, aux): + for param in m.parameters(): + param.grad = None + + print(" [grad_flow] All gradient assertions PASSED ✓") + print("=" * 64 + "\n") + + +# --------------------------------------------------------------------------- +# Main training loop +# --------------------------------------------------------------------------- + +def main(): + args = parse_args() + + device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu") + dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32} + compute_dtype = dtype_map[args.mixed_precision] + torch.manual_seed(args.seed) + os.makedirs(args.out_dir, exist_ok=True) + + # -- Verified regime validation ---------------------------------------- + _NATIVE = dict(height=320, width=512, num_frames=49, guidance_scale=7.0) + is_native = ( + args.height == _NATIVE["height"] + and args.width == _NATIVE["width"] + and args.num_frames == _NATIVE["num_frames"] + ) + print(f"[init] verified_native_regime={is_native} " + f"geometry=({args.num_frames}×{args.height}×{args.width}) " + f"teacher_cfg_scale={args.teacher_cfg_scale if args.enable_teacher_cfg else 'OFF'}") + if not is_native: + print(f"[WARN] Current geometry ({args.num_frames}×{args.height}×{args.width}) " + f"is not the verified native URSA regime " + f"({_NATIVE['num_frames']}×{_NATIVE['height']}×{_NATIVE['width']}). " + "Distillation quality may degrade or become invalid.") + if not args.enable_teacher_cfg: + print("[WARN] Teacher CFG is DISABLED. no_cfg is known to produce " + "blank/blurry output for this URSA checkpoint. " + "Distillation without CFG is unlikely to produce useful results.") + elif args.teacher_cfg_scale != _NATIVE["guidance_scale"]: + print(f"[WARN] teacher_cfg_scale={args.teacher_cfg_scale} differs from " + f"the verified working value ({_NATIVE['guidance_scale']}).") + + if args.enable_teacher_cfg and args.reward_use_guided: + print("[WARN] --reward_use_guided is ON — can cause mode collapse, watch tok_entropy.") + + # -- Load pipeline --------------------------------------------------- + print(f"[init] Loading from {args.teacher_ckpt} …") + pipe = URSAPipeline.from_pretrained( + args.teacher_ckpt, torch_dtype=compute_dtype, trust_remote_code=True + ).to(device) + + tokenizer = pipe.tokenizer + scheduler = pipe.scheduler + scheduler.to(device=device) + + vae_t_stride = getattr(pipe.vae.config, "temporal_stride", 4) + vae_s_stride = getattr(pipe.vae.config, "spatial_stride", 8) + latents_shape = compute_latents_shape( + args.num_frames, args.height, args.width, vae_t_stride, vae_s_stride + ) + T, H, W = latents_shape + N = T * H * W + K = scheduler.codebook_size + print( + f"[init] latents_shape=({T},{H},{W}) N={N} K={K} " + f"CFG={'ON' if args.enable_teacher_cfg else 'OFF'}" + ) + + # -- Pre-compute uncond token IDs (empty string, [1, L]) -------------- + txt_uncond_base = tokenizer( + [""], max_length=args.max_prompt_length, padding="max_length", + padding_side="left", truncation=True, return_tensors="pt", + ).input_ids.to(device) # [1, L] + + # -- Three models ---------------------------------------------------- + teacher = pipe.transformer.eval().requires_grad_(False) + student = copy.deepcopy(teacher).train().requires_grad_(True) + aux = copy.deepcopy(teacher).train().requires_grad_(True) + + # -- flex_attn: reset offsets to None (standard causal attn) --------- + # Our training processes B independent sequences in a batch, so block-packed + # offsets are not needed and must be cleared before any forward call. + if args.dry_run: + print("[init] flex_attn state before reset:") + for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")): + _print_flex_attn_state(m, lbl) + for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")): + _reset_flex_attn(m, lbl, verbose=True) + if args.dry_run: + print("[init] flex_attn state after reset:") + for m, lbl in ((teacher, "teacher"), (student, "student"), (aux, "aux")): + _print_flex_attn_state(m, lbl) + + opt_student = torch.optim.AdamW( + student.parameters(), lr=args.lr_student, weight_decay=args.weight_decay + ) + opt_aux = torch.optim.AdamW( + aux.parameters(), lr=args.lr_aux, weight_decay=args.weight_decay + ) + + # -- Dataset ---------------------------------------------------------- + # dataset = PromptDataset(args.prompt_file, shuffle=True, seed=args.seed) + collate = make_collate_fn(tokenizer, args.max_prompt_length, device) + # loader = DataLoader( + # dataset, batch_size=args.batch_size, shuffle=True, + # drop_last=True, num_workers=0, collate_fn=collate, + # ) + dataset = PromptDataset( + args.prompt_file, + shuffle_files=True, + shuffle_buffer=50000, # 例如 50k buffer,够用且不占太多内存 + seed=args.seed, + infinite=True, + csv=CSVSpec(caption_field="caption"), # Koala 默认就是 caption + ) + + loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, # IMPORTANT for IterableDataset + drop_last=True, + num_workers=2, # 视 IO 调大 + collate_fn=collate, + pin_memory=True, + ) + inf_loader = InfiniteDataLoader(loader) + + # -- Pre-training sanity check --------------------------------------- + _sanity_check_forward(teacher, scheduler, latents_shape, device, K, args.dry_run) + + # -- Training state -------------------------------------------------- + baseline_ema: float = 0.0 + x_hat_prev = None + initial_tok_entropy: float = None + dump_dir = os.path.join(args.out_dir, "debug_dumps") if args.debug_dump > 0 else None + + num_steps = 1 if args.dry_run else args.num_steps + print(f"[train] {'DRY RUN' if args.dry_run else f'{num_steps} steps'} " + f"| CFG={args.enable_teacher_cfg}") + + for step in range(1, num_steps + 1): + + # ---------------------------------------------------------------- + # Stage 1: Tokenise → txt_cond [B, L], txt_uncond [B, L] + # ---------------------------------------------------------------- + txt_cond = next(inf_loader) # [B, L] + txt_cond = txt_cond.to(device, non_blocking=True) + B = txt_cond.size(0) + + txt_uncond = None + if args.enable_teacher_cfg: + txt_uncond = txt_uncond_base.expand(B, -1) # [B, L] + + # ---------------------------------------------------------------- + # Stage 2: x_init ~ Uniform(K) (+ optional p_init mixing) + # ---------------------------------------------------------------- + x_init = _sample_x_init(B, T, H, W, K, device, x_hat_prev, args) + + # ---------------------------------------------------------------- + # Stage 3: Student 1-step forward on x_init — COND only. + # + # Gradient needed: logp and H flow back through p_s → student. + # ---------------------------------------------------------------- + with torch.no_grad(): + ids_init, rpos_init, _ = build_ursa_inputs( + teacher, txt_cond, x_init, latents_shape, device) + logits_s_init = student(ids_init, rope_pos=rpos_init).sample # [B, L+N+1, D] + z_s = extract_visual_logits(logits_s_init.float(), N, K) # [B, N, K] + p_s = F.softmax(z_s / args.tau, dim=-1) # [B, N, K] + x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) # [B, N] + # logp = p_s.clamp(1e-8).log().gather( + # -1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) # [B] + # H_mean = -(p_s * p_s.clamp(1e-8).log()).sum(-1).mean() + x_hat_4d = x_hat.view(B, T, H, W) + + # ---------------------------------------------------------------- + # Stage 4: Pseudo-intermediate x_t + # ---------------------------------------------------------------- + t = sample_t_curriculum(B, device, step, warmup_steps=args.t_curriculum_steps) + with torch.no_grad(): + x_t = scheduler.add_noise(x_hat_4d, t) # [B, T, H, W], long + + # ---------------------------------------------------------------- + # Stage 5: Teacher forward — single [2B] forward when CFG enabled. + # + # ids_dual / rpos_dual are SHARED by teacher, aux, and student to + # avoid redundant input construction. + # ---------------------------------------------------------------- + with torch.no_grad(): + if args.enable_teacher_cfg: + # ONE [2B] forward = cond (first B) + uncond (last B) + ids_dual, rpos_dual, _ = _build_dual_inputs( + teacher, txt_cond, txt_uncond, x_t, latents_shape, device) + logits_T_dual = teacher(ids_dual, rope_pos=rpos_dual).sample.float() + z_T_dual = extract_visual_logits(logits_T_dual, N, K) # [2B, N, K] + z_T_cond, z_T_uncond = z_T_dual.chunk(2, dim=0) # [B, N, K] each + ids_t = ids_dual[:B] # cond half — alias (no copy) + rpos_t = rpos_dual[:B] + else: + ids_t, rpos_t, _ = build_ursa_inputs( + teacher, txt_cond, x_t, latents_shape, device) + logits_T = teacher(ids_t, rope_pos=rpos_t).sample.float() + z_T_cond = extract_visual_logits(logits_T, N, K) # [B, N, K] + z_T_uncond = None + ids_dual = ids_t + rpos_dual = rpos_t + + # -- CFG guided target (float32, per-sample Bernoulli) ---------- + z_T_guided = None + if args.enable_teacher_cfg: + z_T_cond_f = z_T_cond.float() + z_T_uncond_f = z_T_uncond.float() + z_T_guided = _build_guided_logits( + z_T_cond_f, z_T_uncond_f, t, + args.teacher_cfg_scale, args.teacher_cfg_trunc) + + # per-sample Bernoulli: use_guided[b] ~ Bernoulli(p_guided) + p_guided = _cfg_warmup_prob( + step, args.teacher_cfg_prob, args.teacher_cfg_warmup_steps) + use_guided = torch.rand(B, device=device) < p_guided # [B] bool + use_guided_ratio = use_guided.float().mean().item() + z_T_target = _select_target(z_T_guided, z_T_cond_f, use_guided) # [B, N, K] + else: + use_guided = torch.zeros(B, dtype=torch.bool, device=device) + use_guided_ratio = 0.0 + z_T_target = z_T_cond.float() + + # z_T_target is the KD target — must have no grad path to teacher + z_T_target = z_T_target.detach() + + # ---------------------------------------------------------------- + # Stage 6: Aux forward (fake_rounds) — single [2B] forward when CFG. + # ---------------------------------------------------------------- + loss_aux_cond_v_last = None + loss_aux_uncond_v_last = None + loss_aux_cond_sample_last = None + + for _fr in range(args.fake_rounds): + opt_aux.zero_grad() + + if args.enable_teacher_cfg: + # ONE [2B] forward: cond+uncond in one shot + logits_A_dual = aux(ids_dual.detach(), rope_pos=rpos_dual.detach()).sample + z_A_dual = extract_visual_logits(logits_A_dual.float(), N, K) # [2B, N, K] + z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0) + + # Cond: Jeffrey(z_T_target, z_A_cond) + loss_aux_cond_sample = _stable_jeffrey(z_T_target, z_A_cond, args.tau_kd) # [B] + loss_aux_cond_v = loss_aux_cond_sample.mean() + + # Uncond: Jeffrey(z_T_uncond, z_A_uncond) + z_T_uncond_det = z_T_uncond.float().detach() + loss_aux_uncond_sample = _stable_jeffrey(z_T_uncond_det, z_A_uncond, args.tau_kd) + loss_aux_uncond_v = loss_aux_uncond_sample.mean() + + loss_aux_v = loss_aux_cond_v + args.lambda_kd_uncond * loss_aux_uncond_v + else: + logits_A = aux(ids_t.detach(), rope_pos=rpos_t.detach()).sample + z_A_cond = extract_visual_logits(logits_A.float(), N, K) + + loss_aux_cond_sample = _stable_jeffrey(z_T_target, z_A_cond, args.tau_kd) # [B] + loss_aux_cond_v = loss_aux_cond_sample.mean() + loss_aux_uncond_v = torch.tensor(0.0, device=device) + loss_aux_v = loss_aux_cond_v + + loss_aux_v.backward() + if args.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(aux.parameters(), args.grad_clip) + opt_aux.step() + # make sure aux grads are cleared and no graph is retained + for p in aux.parameters(): + p.grad = None + + loss_aux_cond_v_last = loss_aux_cond_v.detach() + loss_aux_uncond_v_last = loss_aux_uncond_v.detach() + loss_aux_cond_sample_last = loss_aux_cond_sample.detach() # [B] + + # # ---------------------------------------------------------------- + # # Stage 7: Student KD forward on x_t — single [2B] when CFG. + # # Dual-path consistency check included. + # # ---------------------------------------------------------------- + # if args.enable_teacher_cfg: + # # ONE [2B] forward + # logits_S_dual = student(ids_dual.detach(), rope_pos=rpos_dual.detach()).sample + # z_S_dual = extract_visual_logits(logits_S_dual.float(), N, K) # [2B, N, K] + # z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) + # else: + # logits_S = student(ids_t.detach(), rope_pos=rpos_t.detach()).sample + # z_S_cond = extract_visual_logits(logits_S.float(), N, K) # [B, N, K] + # z_S_uncond = None + + # # Dual-path shape consistency check + # if z_T_cond.shape != z_S_cond.shape: + # raise RuntimeError( + # f"[FATAL] Dual-path shape mismatch: " + # f"z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape} — " + # "vocab slicing inconsistency." + # ) + + # # KD losses (from raw logits, float32 + log_softmax) + # loss_kd_cond = _stable_kl(z_T_target, z_S_cond, args.tau_kd).mean() + # loss_kd_uncond_v = torch.tensor(0.0, device=device) + + # if args.enable_teacher_cfg and z_S_uncond is not None: + # z_T_uncond_det2 = z_T_uncond.float().detach() + # loss_kd_uncond_v = _stable_kl(z_T_uncond_det2, z_S_uncond, args.tau_kd).mean() + + # loss_kd = loss_kd_cond + args.lambda_kd_uncond * loss_kd_uncond_v + + # # ---------------------------------------------------------------- + # # Stage 8: REINFORCE reward + advantage + # # + # # INVARIANT: reward and adv MUST NOT carry student gradients. + # # - z_S_cond is detached before entering reward computation. + # # - adv is explicitly detached. + # # - Runtime assertions enforce this. + # # ---------------------------------------------------------------- + # if args.enable_teacher_cfg: + # if args.reward_use_guided: + # z_T_for_rew = z_T_target # already detached (guided, see §5) + # else: + # z_T_for_rew = z_T_cond.float().detach() # non-guided cond (stable default) + # # Both inputs are detached: no student gradient leaks into reward. + # reward = -_stable_kl( + # z_T_for_rew.detach(), z_S_cond.detach(), args.tau) # [B] + # else: + # reward = -loss_aux_cond_sample_last # [B], already detached + + # # Mandatory detach assertions: catch reward/adv gradient leaks early. + # assert not reward.requires_grad, ( + # "[BUG] reward.requires_grad=True — student gradient leaked into reward. " + # "Ensure z_S_cond is detached in reward computation." + # ) + # baseline_ema = 0.99 * baseline_ema + 0.01 * reward.mean().item() + # adv = (reward - baseline_ema).detach() # [B] + # assert not adv.requires_grad, "[BUG] adv.requires_grad=True — explicit detach failed" + + # loss_pg = -(adv * logp).mean() + + # # ---------------------------------------------------------------- + # # Stage 9: Student loss + update + # # ---------------------------------------------------------------- + # opt_student.zero_grad() + + # lambda_ent_eff = args.lambda_ent * (1.0 + 2.0 * use_guided_ratio) + # loss_student = ( + # args.lambda_pg * loss_pg + # + args.lambda_kd * loss_kd + # - lambda_ent_eff * H_mean + # ) + + # # Optional surrogate gradient (DiMO MSE trick — applied to Stage-3 logits z_s) + # loss_surr = None + # if args.use_surrogate_grad: + # with torch.no_grad(): + # logits_A_ref = aux(ids_t.detach(), rope_pos=rpos_t.detach()).sample + # z_A_ref = extract_visual_logits(logits_A_ref.float(), N, K) + # # grad_surr = (p_A - p_T): pushes z_s toward teacher distribution + # p_A_ref = F.softmax(z_A_ref.float() / args.tau_kd, dim=-1).detach() + # p_T_surr = F.softmax(z_T_target / args.tau_kd, dim=-1).detach() + # grad_surr = (p_A_ref - p_T_surr).detach() + # loss_surr = 0.5 * F.mse_loss(z_s, (z_s - grad_surr).detach()) + # loss_student = loss_student + args.lambda_surr * loss_surr + + # loss_student.backward() + # if args.grad_clip > 0: + # torch.nn.utils.clip_grad_norm_(student.parameters(), args.grad_clip) + # opt_student.step() + + # # p_init mixing: save x_hat_4d for next step + # x_hat_prev = x_hat_4d.detach().clone() + + # ---------------------------------------------------------------- + # Stage 7: Student KD forward on x_t — single [2B] when CFG. + # ---------------------------------------------------------------- + if args.enable_teacher_cfg: + logits_S_dual = _get_logits(student(ids_dual.detach(), rope_pos=rpos_dual.detach())).float() + z_S_dual = extract_visual_logits(logits_S_dual, N, K) # [2B, N, K] + z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) + else: + logits_S = _get_logits(student(ids_t.detach(), rope_pos=rpos_t.detach())).float() + z_S_cond = extract_visual_logits(logits_S, N, K) + z_S_uncond = None + + if z_T_cond.shape != z_S_cond.shape: + raise RuntimeError(f"[FATAL] Dual-path shape mismatch: z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}") + + loss_kd_cond = _stable_kl(z_T_target, z_S_cond, args.tau_kd).mean() + loss_kd_uncond_v = torch.tensor(0.0, device=device) + if args.enable_teacher_cfg and (z_S_uncond is not None): + loss_kd_uncond_v = _stable_kl(z_T_uncond.float().detach(), z_S_uncond, args.tau_kd).mean() + loss_kd = loss_kd_cond + args.lambda_kd_uncond * loss_kd_uncond_v + + # ---------------------------------------------------------------- + # Stage 8: reward + advantage (detached) + # ---------------------------------------------------------------- + if args.enable_teacher_cfg and args.reward_use_guided: + z_T_for_rew = z_T_target # already detached + else: + z_T_for_rew = z_T_cond.float().detach() + + reward = -_stable_kl(z_T_for_rew.detach(), z_S_cond.detach(), args.tau) # [B] + assert not reward.requires_grad + + baseline_ema = 0.99 * baseline_ema + 0.01 * reward.mean().item() + adv = (reward - baseline_ema).detach() + assert not adv.requires_grad + + # ---------------------------------------------------------------- + # Stage 9: update student in two backward passes (KD then PG/Ent) + # ---------------------------------------------------------------- + opt_student.zero_grad(set_to_none=True) + + # (9a) KD backward first (frees KD graph) + (args.lambda_kd * loss_kd).backward() + + # (9b) Policy + entropy: need a fresh forward on x_init WITH grad + ids_init, rpos_init, _ = build_ursa_inputs(teacher, txt_cond, x_init, latents_shape, device) + logits_s_pol = _get_logits(student(ids_init, rope_pos=rpos_init)).float() + z_s_pol = extract_visual_logits(logits_s_pol, N, K) + + logp_tok = F.log_softmax(z_s_pol / args.tau, dim=-1) # [B,N,K] + p_s_pol = logp_tok.exp() + + # fixed action: x_hat sampled in Stage 3 (no_grad) + logp_sum = logp_tok.gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) # [B], sum over N tokens + logp = logp_sum / N # [B], per-token average logp (RECOMMENDED) + + H_mean = -(p_s_pol * logp_tok).sum(-1).mean() + + loss_pg = -(adv * logp).mean() + + lambda_ent_eff = args.lambda_ent * (1.0 + 2.0 * use_guided_ratio) + (loss_pg * args.lambda_pg - H_mean * lambda_ent_eff).backward() + + # (optional) surrogate grad — put it here; WARNING: extra forward makes it heavier + loss_surr = None + if args.use_surrogate_grad: + with torch.no_grad(): + logits_A_ref = _get_logits(aux(ids_t.detach(), rope_pos=rpos_t.detach())).float() + z_A_ref = extract_visual_logits(logits_A_ref, N, K) + p_A_ref = F.softmax(z_A_ref / args.tau_kd, dim=-1).detach() + p_T_ref = F.softmax(z_T_target / args.tau_kd, dim=-1).detach() + grad_surr = (p_A_ref - p_T_ref).detach() + loss_surr = 0.5 * F.mse_loss(z_s_pol, (z_s_pol - grad_surr).detach()) + (args.lambda_surr * loss_surr).backward() + + if args.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(student.parameters(), args.grad_clip) + opt_student.step() + + # p_init mixing: save x_hat_4d for next step + x_hat_prev = x_hat_4d.detach() #.clone() + + # ---------------------------------------------------------------- + # Post-step: assertions (step 1), collapse detection, logging + # ---------------------------------------------------------------- + if step == 1: + _run_assertions( + x_init, ids_init, rpos_init, + z_s, p_s, logp, + z_T_cond, z_S_cond, x_t, K, N, B, T, H, W, + teacher.config.lm_vocab_size, + z_T_uncond=z_T_uncond, + z_T_guided=z_T_guided, + dry_run=args.dry_run, + ) + + tok_entropy = _token_histogram_entropy(x_hat, K) + if initial_tok_entropy is None: + initial_tok_entropy = tok_entropy + + if tok_entropy < args.collapse_warn_frac * initial_tok_entropy: + print( + f"[COLLAPSE WARNING] step={step} tok_entropy={tok_entropy:.3f} " + f"initial={initial_tok_entropy:.3f} " + f"ratio={tok_entropy/max(initial_tok_entropy, 1e-8):.2f} < " + f"{args.collapse_warn_frac}. " + "Increase --lambda_ent (try 0.05) or --tau." + ) + + if step % args.log_every == 0 or args.dry_run: + surr_str = f" loss_surr={loss_surr.item():.4f}" if loss_surr is not None else "" + print( + f"[step {step:>6d}] " + f"loss_aux_cond={loss_aux_cond_v_last.item():.3e} " + f"loss_aux_uncond={loss_aux_uncond_v_last.item():.3e} " + f"loss_kd_cond={loss_kd_cond.item():.4f} " + f"loss_kd_uncond={loss_kd_uncond_v.item():.4f} " + f"loss_pg={loss_pg.item():.4f}" + f"{surr_str} " + f"H={H_mean.item():.3f} tok_H={tok_entropy:.3f} " + f"guided_ratio={use_guided_ratio:.2f} " + f"baseline={baseline_ema:.4f} " + f"mean_logp_tok={logp.mean().item():.3f}" + ) + + if args.debug_dump > 0 and step % args.debug_dump == 0: + _dump_debug(dump_dir, step, x_hat, K) + + if not args.dry_run and step % args.save_every == 0: + ckpt_dir = os.path.join(args.out_dir, f"step_{step:06d}") + save_checkpoint(student, ckpt_dir, "student") + save_checkpoint(aux, ckpt_dir, "aux") + + # -- dry_run: full grad-flow check after the single training step ---- + if args.dry_run: + print("\n[dry_run] Running gradient flow debug …") + txt_dbg = next(inf_loader) + B_dbg = txt_dbg.size(0) + x_t_dbg = torch.randint(0, K, (B_dbg, T, H, W), device=device, dtype=torch.long) + txt_u_dbg = (txt_uncond_base.expand(B_dbg, -1) + if args.enable_teacher_cfg else None) + debug_grad_flow( + teacher, student, aux, + txt_dbg, txt_u_dbg, x_t_dbg, latents_shape, device, K, N, + args.tau, args.tau_kd, args.enable_teacher_cfg, + ) + _dry_run_patches_789(teacher, latents_shape, K, N, device) + print("[dry_run] Done. All checks (1-9) PASSED. Exiting.") + return + + # Final save + final_dir = os.path.join(args.out_dir, "final") + save_checkpoint(student, final_dir, "student") + save_checkpoint(aux, final_dir, "aux") + print("[done] Training complete.") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _sample_x_init(B, T, H, W, K, device, x_hat_prev, args): + x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long) + if x_hat_prev is not None and args.p_init_mix_ratio > 0: + n_mix = max(1, int(B * args.p_init_mix_ratio)) + x_init[:n_mix] = corrupt_tokens(x_hat_prev[:n_mix], r=args.p_mix_corrupt_frac, K=K) + return x_init + + +def _token_histogram_entropy(x_hat: torch.Tensor, K: int) -> float: + counts = x_hat.flatten().bincount(minlength=K).float() + p = counts / counts.sum() + p = p[p > 0] + return float(-(p * p.log()).sum().item()) + + +def _dump_debug(dump_dir: str, step: int, x_hat: torch.Tensor, K: int): + os.makedirs(dump_dir, exist_ok=True) + counts = x_hat.flatten().bincount(minlength=K).tolist() + with open(os.path.join(dump_dir, f"step_{step:06d}_hist.json"), "w") as fh: + json.dump({"step": step, "counts": counts}, fh) + torch.save(x_hat.cpu(), os.path.join(dump_dir, f"step_{step:06d}_xhat.pt")) + print(f"[debug_dump] step={step} saved to {dump_dir}") + + +def _run_assertions( + x_init, ids_init, rpos_init, + z_s, p_s, logp, + z_T_cond, z_S_cond, x_t, + K, N, B, T, H, W, lm_vocab_size, + z_T_uncond=None, z_T_guided=None, + dry_run=False, +): + """Full shape / value-domain / consistency assertions (run at step=1).""" + print("[assert] Running shape/value assertions …") + + L_plus_N1 = ids_init.size(1) + txt_len = L_plus_N1 - (N + 1) + + # x_init + assert x_init.dtype == torch.long, f"x_init dtype={x_init.dtype}" + assert x_init.min() >= 0 and x_init.max() < K, \ + f"x_init out of [0,K): [{x_init.min()}, {x_init.max()}]" + + # input_ids shape & token value ranges + assert ids_init.shape == (B, L_plus_N1), f"ids_init.shape={ids_init.shape}" + txt_part = ids_init[:, :txt_len] + vis_part = ids_init[:, -N:] + assert (txt_part < lm_vocab_size).all(), \ + f"text tokens bleed into visual range (max={txt_part.max()})" + assert (vis_part >= lm_vocab_size).all(), \ + f"visual tokens not shifted (min={vis_part.min()}, lm_vocab_size={lm_vocab_size})" + assert (vis_part < lm_vocab_size + K).all(), \ + f"visual tokens exceed lm_vocab_size+K (max={vis_part.max()})" + + # rope_pos + assert rpos_init.shape == (B, L_plus_N1, 3), \ + f"rope_pos shape={rpos_init.shape} expected ({B},{L_plus_N1},3)" + + # z_s + assert z_s.shape == (B, N, K), f"z_s.shape={z_s.shape}" + p_err = (p_s.sum(-1) - 1).abs().max().item() + assert p_err < 1e-3, f"p_s not normalised: max deviation={p_err:.2e}" + + # logp + assert not torch.isnan(logp).any(), "logp contains NaN" + assert not torch.isinf(logp).any(), "logp contains Inf" + + # x_t + assert x_t.min() >= 0 and x_t.max() < K, \ + f"x_t out of [0,K) after add_noise: [{x_t.min()}, {x_t.max()}]" + + # Dual-path shape check + assert z_T_cond.shape == z_S_cond.shape, \ + f"Dual-path mismatch: z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}" + assert z_T_cond.shape == (B, N, K), f"z_T_cond.shape={z_T_cond.shape}" + + # z_T logits printout (always in dry_run; also when uncond is available) + if dry_run or z_T_uncond is not None: + print( + f"[assert] z_T_cond shape={z_T_cond.shape} " + f"min={z_T_cond.min():.3f} max={z_T_cond.max():.3f} " + f"mean={z_T_cond.mean():.3f}" + ) + if z_T_uncond is not None: + assert z_T_uncond.shape == (B, N, K), f"z_T_uncond.shape={z_T_uncond.shape}" + print( + f"[assert] z_T_uncond shape={z_T_uncond.shape} " + f"min={z_T_uncond.min():.3f} max={z_T_uncond.max():.3f} " + f"mean={z_T_uncond.mean():.3f}" + ) + if z_T_guided is not None: + assert z_T_guided.shape == (B, N, K), f"z_T_guided.shape={z_T_guided.shape}" + g_min = z_T_guided.min().item() + g_max = z_T_guided.max().item() + g_mean = z_T_guided.mean().item() + print( + f"[assert] z_T_guided shape={z_T_guided.shape} " + f"min={g_min:.3f} max={g_max:.3f} mean={g_mean:.3f}" + ) + # Explosion guard: guided logits must be finite and not excessively large. + assert not torch.isnan(z_T_guided).any(), "z_T_guided contains NaN" + assert not torch.isinf(z_T_guided).any(), "z_T_guided contains Inf" + assert abs(g_min) < 1e4 and abs(g_max) < 1e4, ( + f"z_T_guided magnitude too large: min={g_min:.1e} max={g_max:.1e}. " + f"Reduce --teacher_cfg_scale (currently may amplify outlier logits)." + ) + + print("[assert] All assertions PASSED ✓") + + +def _sanity_check_forward(teacher, scheduler, latents_shape, device, K, verbose=False): + print("[init] Checking logit dimensions …") + T, H, W = latents_shape + N, B, L = T * H * W, 1, 16 + dummy_txt = torch.zeros(B, L, dtype=torch.long, device=device) + dummy_vis = torch.zeros(B, T, H, W, dtype=torch.long, device=device) + with torch.no_grad(): + ids, rpos, _ = build_ursa_inputs(teacher, dummy_txt, dummy_vis, latents_shape, device) + logits = teacher(ids, rope_pos=rpos).sample + lm_head_size = teacher.config.lm_head_size + lm_vocab = teacher.config.lm_vocab_size + print( + f"[init] logits={logits.shape} K={K} " + f"lm_head={lm_head_size} lm_vocab={lm_vocab}" + ) + assert ids.shape == (B, L + N + 1), f"ids shape {ids.shape}" + assert rpos.shape == (B, L + N + 1, 3), f"rpos shape {rpos.shape}" + z = extract_visual_logits(logits.float(), N, K) + assert z.shape == (B, N, K), f"z shape {z.shape}" + assert lm_head_size >= K, f"lm_head_size={lm_head_size} < K={K}" + if verbose: + print("[init] flex_attn state during sanity check:") + _print_flex_attn_state(teacher, "teacher") + print("[init] Forward check OK ✓") + + +# --------------------------------------------------------------------------- +# Dry-run patches 7 / 8 / 9 +# --------------------------------------------------------------------------- + +def _dry_run_patches_789(teacher, latents_shape, K, N, device): + """Three deep self-checks executed only during --dry_run. + + Patch 7 — extract_visual_logits end-to-end alignment: + Run a real teacher forward, manually reconstruct z_manual from raw logits + using the latent_shift / codebook_size convention, and assert the result + matches extract_visual_logits(). Handles the common URSA case where + lm_head outputs K logits directly (latent_shift not applied to logit dim). + + Patch 8 — flex_attn semantics sanity: + If the model exposes set_offsets_by_lens, compare visual-logit mean-delta + between offsets=None (standard causal) and a single-block offset. A large + delta is expected and confirms that our training correctly uses offsets=None. + Gracefully skips when flex_attention is unavailable at runtime. + + Patch 9 — logp / token reshape consistency: + With a small (T=3, H=4, W=5) shape, verify x_hat reshape round-trips and + spot-check 10 token positions against manually computed log-probability. + """ + T, H, W = latents_shape + L_test, B_test = 16, 1 + + print("\n" + "=" * 64) + print("[patch 7/8/9] Running additional dry_run self-checks …") + + # ------------------------------------------------------------------------- + # Build shared dummy inputs used by both patch 7 and patch 8 + # ------------------------------------------------------------------------- + dummy_txt = torch.zeros(B_test, L_test, dtype=torch.long, device=device) + dummy_vis = torch.zeros(B_test, T, H, W, dtype=torch.long, device=device) + with torch.no_grad(): + ids_test, rpos_test, _ = build_ursa_inputs( + teacher, dummy_txt, dummy_vis, latents_shape, device) + logits_full = teacher(ids_test, rope_pos=rpos_test).sample.float() # [1, L+N+1, D] + + D = logits_full.size(-1) # actual logit last-dim (lm_head_size) + latent_shift = teacher.config.lm_vocab_size # text-vocab offset for input token IDs + + # ========================================================================= + # Patch 7 — extract_visual_logits end-to-end alignment + # ========================================================================= + print("\n[7] extract_visual_logits end-to-end alignment …") + z_vis = extract_visual_logits(logits_full, N, K) # [1, N, K] + assert z_vis.shape == (B_test, N, K), f"z_vis.shape={z_vis.shape}" + + if D >= latent_shift + K: + # Full-vocab head: logit dim covers text (0..latent_shift) + visual tokens. + z_seq = logits_full[:, -(N + 1) : -1] # [1, N, D] + z_manual = z_seq[..., latent_shift : latent_shift + K] # [1, N, K] + delta = (z_vis - z_manual).abs().max().item() + print(f" [7] path=full-vocab D={D} latent_shift+K={latent_shift + K}") + print(f" [7] z_vis.shape={z_vis.shape} max|z_vis - z_manual|={delta:.2e}") + assert delta < 1e-5, ( + f"extract_visual_logits mismatch (full-vocab path): delta={delta:.2e}. " + "The function should return logits[..., latent_shift:latent_shift+K]." + ) + print("[7] extract_visual_logits alignment PASSED ✓") + + else: + # Common URSA case: lm_head outputs K logits directly (lm_head_size ≈ K). + # latent_shift is the input token-ID offset, NOT a logit-dimension offset. + # extract_visual_logits handles this as D==K (happy path) or D>K (offset=D-K). + z_seq = logits_full[:, -(N + 1) : -1] # [1, N, D] + if D == K: + delta = (z_vis - z_seq).abs().max().item() + print( + f" [7] SKIP latent_shift formula: D={D} == K={K} " + f"latent_shift={latent_shift}.\n" + f" [7] Explanation: URSA lm_head outputs K visual logits directly.\n" + f" [7] latent_shift={latent_shift} is the input token-ID shift " + f"(raw_code + lm_vocab_size), NOT a logit-dim offset.\n" + f" [7] extract_visual_logits happy-path: z = logits[:, -(N+1):-1] " + f"(no vocab-dim slicing).\n" + f" [7] Fallback check: z_vis == raw causal slice " + f"max_delta={delta:.2e}" + ) + assert delta < 1e-5, ( + f"z_vis != raw causal slice when D==K: delta={delta:.2e}" + ) + else: + # D > K but D < latent_shift + K → extract uses offset = D - K + offset = D - K + z_manual = z_seq[..., offset:] + delta = (z_vis - z_manual).abs().max().item() + print( + f" [7] SKIP latent_shift formula: D={D} < latent_shift+K={latent_shift + K}.\n" + f" [7] extract_visual_logits uses offset={offset} (D-K). " + f"max_delta={delta:.2e}" + ) + assert delta < 1e-5, ( + f"z_vis != z_seq[..., D-K:]: delta={delta:.2e}" + ) + print("[7] extract_visual_logits alignment PASSED (fallback path) ✓") + + # ========================================================================= + # Patch 8 — flex_attn semantics sanity + # ========================================================================= + print("\n[8] flex_attn semantics sanity …") + fa = _probe_flex_attn(teacher) + if fa is None or not hasattr(fa, "set_offsets_by_lens"): + print(" [8] flex_attn.set_offsets_by_lens not available — skip") + print("[8] flex_attn semantics sanity PASSED (skipped — no flex_attn) ✓") + else: + L_total = ids_test.size(1) # L_test + N + 1 + txt_block = L_test + (N + 1) # single-block: all tokens in one block + block_lens = [txt_block] + + try: + # Forward A: offsets=None — standard causal attention (our training config) + _reset_flex_attn(teacher, "teacher") + with torch.no_grad(): + logits_A = teacher(ids_test, rope_pos=rpos_test).sample.float() + z_A = extract_visual_logits(logits_A, N, K) + + # Forward B: set_offsets_by_lens with a single block. + # A single block causes the mask to allow full (bidirectional) attention + # within the block, which differs from standard causal attention. + fa.set_offsets_by_lens(block_lens) + with torch.no_grad(): + logits_B = teacher(ids_test, rope_pos=rpos_test).sample.float() + z_B = extract_visual_logits(logits_B, N, K) + + delta_mean = (z_A - z_B).abs().mean().item() + delta_max = (z_A - z_B).abs().max().item() + print( + f" [8] offsets=None vs set_offsets_by_lens({block_lens}):\n" + f" [8] mean_abs_delta={delta_mean:.4e} max_abs_delta={delta_max:.4e}" + ) + if delta_mean > 1e-3: + print( + f" [8] WARNING: mean_delta={delta_mean:.2e} > 1e-3.\n" + " [8] Single-block flex_attn uses FULL (bidirectional) attention\n" + " [8] inside the block, whereas offsets=None gives standard CAUSAL\n" + " [8] attention. This difference is EXPECTED — it confirms our\n" + " [8] training correctly uses offsets=None (no packed sequences)." + ) + else: + print(f" [8] delta ≤ 1e-3: attention semantics equivalent for this input.") + print("[8] flex_attn semantics sanity PASSED ✓") + + except (NotImplementedError, RuntimeError, Exception) as exc: + print(f" [8] flex_attn runtime not available ({type(exc).__name__}: {exc}) — skip") + print("[8] flex_attn semantics sanity PASSED (runtime skip) ✓") + finally: + _reset_flex_attn(teacher, "teacher") # always restore clean state + + # ========================================================================= + # Patch 9 — logp / token reshape consistency + # ========================================================================= + print("\n[9] logp/token reshape consistency …") + T9, H9, W9 = 3, 4, 5 + N9, B9 = T9 * H9 * W9, 1 # 60 tokens, batch=1 + + torch.manual_seed(99) + z9 = torch.randn(B9, N9, K) + p9 = F.softmax(z9 / 1.0, dim=-1) # [1, 60, K]; each row sums to 1 + + # ----- token sampling --------------------------------------------------- + x_hat_flat = torch.multinomial(p9.view(-1, K), 1) # [N9, 1] (1 sample per row) + x_hat_1d = x_hat_flat.view(B9, N9) # [1, 60] + x_hat_4d = x_hat_1d.view(B9, T9, H9, W9) # [1, 3, 4, 5] + + # reshape round-trip: 1d → 4d → 1d must be lossless + x_hat_back = x_hat_4d.view(B9, N9) + assert torch.equal(x_hat_1d, x_hat_back), ( + f"reshape round-trip FAILED: x_hat_1d != x_hat_4d.view(B,N)\n" + f" x_hat_1d.shape={x_hat_1d.shape} x_hat_back.shape={x_hat_back.shape}" + ) + + # ----- logp computation (mirrors training code) ------------------------- + # logp_all[b, n] = log p9[b, n, x_hat_1d[b, n]] + logp_all = ( + p9.clamp(1e-8).log() + .gather(-1, x_hat_1d.unsqueeze(-1)) + .squeeze(-1) + ) # [B9, N9] + logp_sum = logp_all.sum(-1) # [B9] + + # ----- spot-check 10 random token positions ----------------------------- + torch.manual_seed(7) + positions = torch.randperm(N9)[:10].tolist() + for pos in positions: + tok_id = x_hat_1d[0, pos].item() + logp_man = math.log(max(p9[0, pos, tok_id].item(), 1e-8)) + logp_gat = logp_all[0, pos].item() + diff = abs(logp_man - logp_gat) + assert diff < 1e-6, ( + f"logp mismatch at pos={pos}, tok={tok_id}: " + f"manual={logp_man:.8f} gathered={logp_gat:.8f} diff={diff:.2e}" + ) + + # check logp_sum matches sum of logp_all + logp_sum_manual = logp_all[0].sum().item() + assert abs(logp_sum.item() - logp_sum_manual) < 1e-5, \ + f"logp_sum mismatch: {logp_sum.item():.6f} vs {logp_sum_manual:.6f}" + + print( + f" [9] T={T9},H={H9},W={W9} N={N9} K={K} " + f"x_hat reshape round-trip ✓ " + f"10 logp spot-checks (pos={positions}) ✓ " + f"logp_sum={logp_sum.item():.3f}" + ) + print("[9] logp/token reshape consistency PASSED ✓") + + print("\n" + "=" * 64) + print("[patch 7/8/9] All 3 additional dry_run checks PASSED ✓") + print("=" * 64) + + +if __name__ == "__main__": + main() diff --git a/URSA/src/__init__.py b/URSA/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/URSA/tom/ursa.jpg b/URSA/tom/ursa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c85b83057769f3ee78c5ef87051f0ff836819339 Binary files /dev/null and b/URSA/tom/ursa.jpg differ