| |
| """ |
| Download high-quality STEM datasets for SHOREKEEPER |
| Math, Code, Science - No random web text |
| """ |
|
|
import json
import random
from pathlib import Path

from datasets import load_dataset
|
|
def _save_jsonl(path, items):
    """Write *items* to *path* as JSON Lines (one object per line, UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        for item in items:
            f.write(json.dumps(item) + "\n")


def _scienceqa_body(item):
    """Map one ScienceQA record to the response body text, or "" to skip it.

    ScienceQA stores `answer` as an integer index into `choices` (per the
    dataset card), so the previous `if question and answer:` check dropped
    every example whose correct choice was index 0 and embedded a bare
    integer in the response. Falls back to str(answer) when `choices` is
    missing. NOTE(review): confirm against the derek-thomas/ScienceQA schema.
    """
    answer = item.get("answer")
    if answer is None or answer == "":
        return ""
    choices = item.get("choices") or []
    if isinstance(answer, int) and 0 <= answer < len(choices):
        answer = choices[answer]
    return f"Science explanation:\n{answer}"


def _collect(all_data, header, load_args, split, source, kind, get_prompt, get_body):
    """Load one Hub dataset split and append normalized examples to *all_data*.

    get_prompt/get_body map a raw record to the prompt text and the response
    body (wrapped in |special_token| markers here). Records where either is
    empty are skipped. Download failures are logged and swallowed so one bad
    dataset does not abort the whole run (best-effort by design).
    """
    print(header)
    try:
        dataset = load_dataset(*load_args, split=split)
        print(f"   Loading {len(dataset)} examples...")
        added = 0
        for item in dataset:
            prompt = get_prompt(item)
            body = get_body(item)
            if prompt and body:
                all_data.append({
                    "prompt": prompt,
                    "response": f"|special_token| {body} |special_token|",
                    "source": source,
                })
                added += 1
        # Report what was actually kept — not len(dataset), which overstated
        # the count whenever records were filtered out above.
        print(f"   β Added {added} {kind} examples")
    except Exception as e:
        print(f"   β Failed: {e}")


def download_stem_data():
    """Download curated STEM instruction datasets and write train/val JSONL.

    Pulls math, code, and science data from the Hugging Face Hub, normalizes
    each example to {"prompt", "response", "source"} with |special_token|
    response markers, and writes a disjoint 95/5 train/validation split.

    Side effects: creates ./data/stem/ and writes stem_train.jsonl and
    stem_val.jsonl there. Requires network access for `load_dataset`.
    """
    print("=" * 70)
    print("DOWNLOADING STEM DATASETS")
    print("=" * 70)

    data_dir = Path("./data/stem")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []

    _collect(all_data,
             "\n1. MetaMathQA (395k math problems)...",
             ("meta-math/MetaMathQA",), "train",
             "metamath", "math",
             lambda r: r.get("query", ""),
             lambda r: r.get("response", ""))

    _collect(all_data,
             "\n2. CodeFeedback (1.2M code examples - taking 200k)...",
             ("m-a-p/CodeFeedback",), "train[:200000]",
             "codefeedback", "code",
             lambda r: r.get("instruction", ""),
             lambda r: f"Here's the code:\n{r['output']}" if r.get("output") else "")

    _collect(all_data,
             "\n3. NuminaMath-CoT (860k math problems - taking 200k)...",
             ("AI-MO/NuminaMath-CoT",), "train[:200000]",
             "numinamath", "math",
             lambda r: r.get("problem", ""),
             lambda r: (f"Let me solve this step by step.\n{r['solution']}"
                        if r.get("solution") else ""))

    _collect(all_data,
             "\n4. ScienceQA (21k science questions)...",
             ("derek-thomas/ScienceQA",), "train",
             "scienceqa", "science",
             lambda r: r.get("question", ""),
             _scienceqa_body)

    _collect(all_data,
             "\n5. GSM8K (8.5k grade school math)...",
             ("gsm8k", "main"), "train",
             "gsm8k", "math",
             lambda r: r.get("question", ""),
             # Keep only the final answer after the '#### ' marker.
             lambda r: r.get("answer", "").split("####")[-1].strip())

    print("\n" + "=" * 70)
    print(f"TOTAL STEM EXAMPLES: {len(all_data):,}")
    print("=" * 70)

    sources = {}
    for item in all_data:
        sources[item["source"]] = sources.get(item["source"], 0) + 1

    print("\nBreakdown by source:")
    for src, count in sources.items():
        print(f"  {src}: {count:,}")

    # Shuffle deterministically before splitting; without this the 5% val
    # slice was entirely the last dataset downloaded (GSM8K).
    random.Random(42).shuffle(all_data)

    split_idx = int(len(all_data) * 0.95)
    train, val = all_data[:split_idx], all_data[split_idx:]

    print("\nSaving to disk...")
    # Write DISJOINT splits. The previous version dumped ALL examples (the
    # val slice included) into stem_train.jsonl, leaking val into training.
    _save_jsonl(data_dir / "stem_train.jsonl", train)
    _save_jsonl(data_dir / "stem_val.jsonl", val)

    print(f"β Saved to: {data_dir}/stem_train.jsonl")
    print(f"  Train: {len(train):,}")
    print(f"  Val: {len(val):,}")
|
|
if __name__ == "__main__":
    # Script entry point: run the full download → normalize → split pipeline.
    download_stem_data()
|
|