import os import json import random import datasets from datasets import load_dataset, Dataset, concatenate_datasets from utils import load_jsonl, lower_keys def load_data(data_name, split, data_dir="./data"): data_file = f"{data_dir}/{data_name}/{split}.jsonl" if os.path.exists(data_file): examples = list(load_jsonl(data_file)) else: if data_name == "math": dataset = load_dataset( "competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp", ) elif data_name == "gsm8k": dataset = load_dataset(data_name, split=split) elif data_name == "svamp": # evaluate on training set + test set dataset = load_dataset("ChilleD/SVAMP", split="train") dataset = concatenate_datasets( [dataset, load_dataset("ChilleD/SVAMP", split="test")] ) elif data_name == "asdiv": dataset = load_dataset("EleutherAI/asdiv", split="validation") dataset = dataset.filter( lambda x: ";" not in x["answer"] ) # remove multi-answer examples elif data_name == "mawps": examples = [] # four sub-tasks for data_name in ["singleeq", "singleop", "addsub", "multiarith"]: sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl")) for example in sub_examples: example["type"] = data_name examples.extend(sub_examples) dataset = Dataset.from_list(examples) elif data_name == "mmlu_stem": dataset = load_dataset("hails/mmlu_no_train", "all", split="test") # only keep stem subjects stem_subjects = [ "abstract_algebra", "astronomy", "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", "college_physics", "computer_security", "conceptual_physics", "electrical_engineering", "elementary_mathematics", "high_school_biology", "high_school_chemistry", "high_school_computer_science", "high_school_mathematics", "high_school_physics", "high_school_statistics", "machine_learning", ] dataset = dataset.rename_column("subject", "type") dataset = dataset.filter(lambda x: x["type"] in stem_subjects) elif data_name == "carp_en": dataset = load_jsonl(f"{data_dir}/carp_en/test.jsonl") else: raise NotImplementedError(data_name) examples = list(dataset) examples = [lower_keys(example) for example in examples] dataset = Dataset.from_list(examples) os.makedirs(f"{data_dir}/{data_name}", exist_ok=True) dataset.to_json(data_file) # add 'idx' in the first column if "idx" not in examples[0]: examples = [{"idx": i, **example} for i, example in enumerate(examples)] # dedepulicate & sort examples = sorted(examples, key=lambda x: x["idx"]) return examples