| """ |
| Data API |
| """ |
| import json |
|
|
| from datasets import Dataset, load_dataset, load_from_disk, concatenate_datasets |
| from prompts import gsm8k_prompt, asdiv_aug_prompt, math_500_prompt, aime_prompt, strategy_qa_prompt, du_prompt |
|
|
| def get_dataset(data_name_or_path, tokenizer, prompt_idx): |
| """ |
| Args: |
| data_name_or_path: dataset name or path |
| tokenizer: tokenizer |
| prompt_idx: which query prompt to use |
| Returns: |
| dataset: dataset |
| """ |
|
|
| |
| if "gsm8k" in data_name_or_path.lower(): |
| try: |
| dataset = load_from_disk(data_name_or_path)['test'] |
| except: |
| dataset = load_dataset("openai/gsm8k", "socratic")["test"] |
| question_col = "question" |
| answer_col = "answer" |
| |
| elif "asdiv-aug" in data_name_or_path.lower(): |
| try: |
| dataset = load_from_disk(data_name_or_path)['test'] |
| except: |
| dataset = load_dataset("xuyige/ASDiv-Aug")["test"] |
| question_col = "question" |
| answer_col = "answer" |
|
|
| elif "math-500" in data_name_or_path.lower(): |
| try: |
| dataset = load_from_disk(data_name_or_path)['test'] |
| except: |
| dataset = load_dataset("HuggingFaceH4/MATH-500")["test"] |
| question_col = "problem" |
| answer_col = "answer" |
|
|
| elif "aime_2024" in data_name_or_path.lower(): |
| try: |
| dataset = load_from_disk(data_name_or_path) |
| except: |
| dataset = load_dataset("Maxwell-Jia/AIME_2024")['train'] |
| question_col = "Problem" |
| answer_col = "Answer" |
| |
| elif "aime2025" in data_name_or_path.lower(): |
| try: |
| dataset = load_from_disk(data_name_or_path) |
| except: |
| dataset = concatenate_datasets([ |
| load_dataset("opencompass/AIME2025", "AIME2025-I")['test'], |
| load_dataset("opencompass/AIME2025", "AIME2025-II")['test'], |
| ]) |
| question_col = "question" |
| answer_col = "answer" |
|
|
| elif "strategyqa" in data_name_or_path.lower(): |
| return get_strategyqa(tokenizer, prompt_idx) |
|
|
| elif "date_understanding" in data_name_or_path.lower(): |
| dataset = load_dataset("maveriq/bigbenchhard", "date_understanding")['train'] |
| question_col = "input" |
| answer_col = "target" |
|
|
| else: |
| raise ValueError(f"Unsupported dataset: {data_name_or_path}") |
|
|
| |
| def preprocess_function(examples): |
| ''' |
| Preprocess dataset |
| |
| Args: |
| examples: dataset examples |
| |
| Returns: |
| formatted: formatted dataset |
| ''' |
| prompt = [] |
| formatted = [] |
| answers = examples[answer_col] |
| questions = examples[question_col] |
| for q in questions: |
| if "gsm8k" in data_name_or_path.lower(): |
| messages = gsm8k_prompt(q, prompt_idx) |
| elif "asdiv-aug" in data_name_or_path.lower(): |
| messages = asdiv_aug_prompt(q, prompt_idx) |
| elif "math-500" in data_name_or_path.lower(): |
| messages = math_500_prompt(q, prompt_idx) |
| elif "aime" in data_name_or_path.lower(): |
| messages = aime_prompt(q, prompt_idx) |
| elif "date_understanding" in data_name_or_path.lower(): |
| messages = du_prompt(q, prompt_idx) |
| else: |
| raise ValueError(f"Unsupported dataset: {data_name_or_path}") |
|
|
| prompt.append(messages) |
| formatted.append(tokenizer.apply_chat_template( |
| messages, tokenize=False, add_generation_prompt=True |
| )) |
| if "aime" in data_name_or_path.lower() and "2025" in data_name_or_path.lower(): |
| answers = [ans.replace('^\circ', '') for ans in answers] |
| if "date_understanding" in data_name_or_path.lower(): |
| answers = [ans[1] for ans in answers] |
| return { |
| "prompt": prompt, |
| "formatted": formatted, |
| "question": questions, |
| "answer": answers, |
| } |
|
|
| dataset = dataset.map(preprocess_function, batched=True, load_from_cache_file=False) |
| return dataset |
|
|
| def get_strategyqa(tokenizer, prompt_idx): |
| prompt = [] |
| formatted = [] |
| answers = [] |
| questions = [] |
| with open('strategyqa_train.json', 'r') as f: |
| data = json.load(f) |
| for ins in data: |
| q, a = ins['question'], ins['answer'] |
| msg = strategy_qa_prompt(q, prompt_idx) |
| questions.append(q) |
| answers.append(a) |
| prompt.append(msg) |
| formatted.append(tokenizer.apply_chat_template( |
| msg, tokenize=False, add_generation_prompt=True |
| )) |
| return Dataset.from_dict({ |
| "prompt": prompt, |
| "formatted": formatted, |
| "question": questions, |
| "answer": answers, |
| }) |
|
|