"""
MM Math Dataset Preprocessing Script

This script downloads, splits, and preprocesses the MM Math dataset.
The preprocessed data will be saved to the data/mm_math directory with
fields: prompt, completion, solution, image_path.

Usage:
    # Using config file
    uv run scripts/mm_math_process.py --config configs/latent_memory/mm_math.yaml

    # Manual parameters
    uv run scripts/mm_math_process.py --output_dir data/mm_math --val_ratio 0.05 --test_ratio 0.05
"""

import argparse
import io
import json
import logging
import os
import re
import zipfile
from typing import Dict, List, Optional

import requests
import yaml
from datasets import DatasetDict, load_dataset

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def download_data(cache_path: str) -> str:
    """Download MM Math JSONL data.

    Args:
        cache_path: Directory to save downloaded data

    Returns:
        Path to downloaded JSONL file
    """
    os.makedirs(cache_path, exist_ok=True)
    jsonl_url = "https://huggingface.co/datasets/THU-KEG/MM_Math/resolve/main/MM_Math/MM_Math.jsonl"
    jsonl_path = os.path.join(cache_path, "mm_math.jsonl")

    if os.path.exists(jsonl_path):
        logging.info(f"JSONL file already exists at {jsonl_path}")
    else:
        logging.info(f"Downloading MM_Math dataset to {jsonl_path}")
        resp = requests.get(jsonl_url, timeout=120)
        resp.raise_for_status()
        with open(jsonl_path, "wb") as f:
            f.write(resp.content)
        logging.info("Download complete")

    return jsonl_path
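

# Optional sketch (not used by the pipeline above): for very large downloads,
# streaming the response to disk avoids holding the whole payload in memory.
# `download_file_streamed` is a hypothetical helper shown for illustration.
def download_file_streamed(url: str, dest_path: str, chunk_size: int = 1 << 20):
    """Stream `url` to `dest_path` in fixed-size chunks."""
    with requests.get(url, stream=True, timeout=300) as resp:
        resp.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                f.write(chunk)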


def download_images(image_root: str):
    """Download MM Math images.

    Args:
        image_root: Directory to extract images to
    """
    if os.path.isdir(image_root) and len(os.listdir(image_root)) > 0:
        logging.info(f"Images already exist at {image_root}")
        return

    os.makedirs(image_root, exist_ok=True)
    logging.info(f"Downloading MM_Math images to {image_root}")
    zip_url = "https://huggingface.co/datasets/THU-KEG/MM_Math/resolve/main/MM_Math/MM_Math.zip"
    resp = requests.get(zip_url, timeout=300)
    resp.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(resp.content)) as z:
        z.extractall(image_root)
    logging.info("Image download complete")


def split_dataset(jsonl_path: str, val_ratio: float = 0.05, test_ratio: float = 0.05) -> DatasetDict:
    """Split dataset into train/valid/test.

    Args:
        jsonl_path: Path to the JSONL file
        val_ratio: Validation set ratio
        test_ratio: Test set ratio

    Returns:
        DatasetDict with train/valid/test splits
    """
    if val_ratio < 0 or test_ratio < 0 or (val_ratio + test_ratio) >= 1.0:
        raise ValueError("Invalid val_ratio/test_ratio; ensure 0 <= ratios and sum < 1.0")

    logging.info(f"Loading dataset from {jsonl_path}")
    raw_ds = load_dataset("json", data_files={"train": jsonl_path})
    base_train = raw_ds["train"]

    logging.info(f"Splitting dataset: val_ratio={val_ratio}, test_ratio={test_ratio}")
    holdout_ratio = val_ratio + test_ratio
    if holdout_ratio > 0:
        first_split = base_train.train_test_split(test_size=holdout_ratio, seed=42, shuffle=True)
        split_train = first_split["train"]
        holdout = first_split["test"]
        # train_test_split rejects a test_size of 0.0 or 1.0, so handle the
        # single-sided cases (only valid or only test requested) explicitly.
        if val_ratio > 0 and test_ratio > 0:
            second_split = holdout.train_test_split(test_size=test_ratio / holdout_ratio, seed=42, shuffle=True)
            split_valid = second_split["train"]
            split_test = second_split["test"]
        elif val_ratio > 0:
            split_valid, split_test = holdout, holdout.select(range(0))
        else:
            split_valid, split_test = holdout.select(range(0)), holdout
    else:
        # No holdout requested: everything stays in train, valid/test are empty.
        split_train = base_train
        split_valid = base_train.select(range(0))
        split_test = base_train.select(range(0))

    dataset_dict = DatasetDict({
        "train": split_train,
        "valid": split_valid,
        "test": split_test,
    })

    logging.info(f"Split sizes - train: {len(split_train)}, valid: {len(split_valid)}, test: {len(split_test)}")
    return dataset_dict
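
# Example of the ratio arithmetic above: with 10,000 rows and
# val_ratio = test_ratio = 0.05, the first split holds out 1,000 rows, and the
# second split divides the holdout with test fraction 0.05 / 0.10 = 0.5,
# yielding 9,000 / 500 / 500 train/valid/test rows.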


def preprocess_batch(batch: Dict, image_root: str) -> Dict:
    """Preprocess a batch of examples.

    Args:
        batch: Batch of raw examples
        image_root: Root directory for images

    Returns:
        Preprocessed batch with fields:
        - prompt: formatted question prompt
        - completion: original full solution text
        - solution: extracted boxed answer (for reward computation)
        - image_path: path to image file
    """
    def _format_answer(answer: str) -> str:
        answer = (answer or "").strip()
        if answer.startswith("\\boxed{") and answer.endswith("}"):
            return answer
        return "\\boxed{" + answer + "}"

    def _extract_answer_from_solution(solution_text: str) -> str:
        r"""Extract answer from solution text.

        Extracts content between matching braces after \boxed{.
        This handles nested braces correctly, e.g., \boxed{72^{\circ}}
        extracts 72^{\circ}.
        """
        if not solution_text:
            return ""

        # Match "boxed{" preceded by one or two backslashes, since escaping
        # varies across raw JSONL entries.
        boxed_pattern = r"\\\\?boxed\{"
        matches = list(re.finditer(boxed_pattern, solution_text))

        if matches:
            # Use the last \boxed{...} occurrence as the final answer.
            last_match = matches[-1]
            start_pos = last_match.end()

            # Scan forward, tracking brace depth so nested braces survive.
            brace_count = 1
            pos = start_pos
            while pos < len(solution_text) and brace_count > 0:
                if solution_text[pos] == '{':
                    brace_count += 1
                elif solution_text[pos] == '}':
                    brace_count -= 1
                pos += 1

            if brace_count == 0:
                # Exclude the closing brace itself.
                return solution_text[start_pos:pos - 1].strip()

        return ""

    # A single backslash (\boxed) and a trailing newline keep the instruction
    # consistent with the \boxed{...} labels produced below and separate it
    # from the question line.
    format_template = "Solve the problem and output the answer in the format of \\boxed{your answer}.\n"
    prompt_template = "Question: {prompt}\n"

    # Fall back to empty strings / parallel columns when a field is missing.
    questions: List[str] = batch.get("question") or [""] * len(batch.get("solution", []))
    answers_src: List[str] = batch.get("answer") or [""] * len(questions)
    solutions_src: List[str] = batch.get("solution") or [""] * len(questions)
    file_names_src = batch.get("file_name", [None] * len(questions))
    image_paths_src = batch.get("image_path", [None] * len(questions))
    file_names: List[Optional[str]] = [fn if fn is not None else ip for fn, ip in zip(file_names_src, image_paths_src)]

    prompts: List[str] = []
    completions: List[str] = []
    solutions: List[str] = []
    image_paths: List[Optional[str]] = []

    for q, a_src, sol_src, fn in zip(questions, answers_src, solutions_src, file_names):
        processed_prompt = format_template + prompt_template.format(prompt=(q or "").strip())

        # Prefer the explicit answer field; otherwise mine the boxed answer
        # out of the full solution text.
        answer_text = (a_src or "").strip()
        if len(answer_text) == 0:
            answer_text = _extract_answer_from_solution((sol_src or "").strip())

        completion_text = (sol_src or "").strip()
        solution_label = _format_answer(answer_text)

        prompts.append(processed_prompt)
        completions.append(completion_text)
        solutions.append(solution_label)

        if fn is not None:
            image_paths.append(os.path.join(image_root, fn))
        else:
            image_paths.append(None)

    return {
        "prompt": prompts,
        "completion": completions,
        "solution": solutions,
        "image_path": image_paths,
    }
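
# Worked example (illustrative values): given a row whose solution text ends
# with "... hence \boxed{72^{\circ}}", _extract_answer_from_solution returns
# "72^{\circ}" (the brace counter tolerates the nested {\circ}), and
# _format_answer re-wraps it as "\boxed{72^{\circ}}" for the solution field.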


def preprocess_dataset(dataset_dict: DatasetDict, image_root: str, batch_size: int = 512) -> DatasetDict:
    """Preprocess all splits.

    Args:
        dataset_dict: Raw dataset dictionary
        image_root: Root directory for images
        batch_size: Batch size for processing

    Returns:
        Preprocessed DatasetDict with fields: prompt, completion, solution, image_path
    """
    def _map(split):
        logging.info(f"Preprocessing {split} split with batch_size={batch_size}")
        ds = dataset_dict[split].map(
            lambda batch: preprocess_batch(batch, image_root),
            batched=True,
            batch_size=batch_size,
            num_proc=None,
            remove_columns=dataset_dict[split].column_names,
            desc=f"MM_Math preprocess ({split})",
        )

        def has_valid_solution(example):
            solution = (example.get("solution") or "").strip()
            # _format_answer wraps even an empty answer as "\boxed{}", so
            # treat that sentinel as missing too.
            return len(solution) > 0 and solution != "\\boxed{}"

        original_size = len(ds)
        ds = ds.filter(has_valid_solution, num_proc=None, desc=f"Filter empty solutions ({split})")
        filtered_size = len(ds)
        if original_size != filtered_size:
            logging.warning(f"{split}: Filtered out {original_size - filtered_size} samples with empty solutions")

        logging.info(f"Preprocessing done for {split}: {len(ds)} samples")
        return ds

    processed_dict = DatasetDict()
    for split_name in dataset_dict.keys():
        processed_dict[split_name] = _map(split_name)

    return processed_dict


def save_processed_data(dataset_dict: DatasetDict, output_dir: str):
    """Save preprocessed data to JSON files.

    Args:
        dataset_dict: Preprocessed dataset dictionary
        output_dir: Output directory to save JSON files
    """
    os.makedirs(output_dir, exist_ok=True)

    for split_name, dataset in dataset_dict.items():
        output_path = os.path.join(output_dir, f"{split_name}.json")
        logging.info(f"Saving {split_name} split to {output_path} ({len(dataset)} samples)")

        data = [dict(item) for item in dataset]
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    logging.info(f"All splits saved to {output_dir}")
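
# The saved files are plain JSON arrays, so they can be loaded back either
# with json.load or (assuming the Hugging Face "json" builder, which accepts
# list-of-dicts files) with load_dataset, e.g.:
#
#     reloaded = load_dataset("json", data_files={"train": "data/mm_math/train.json"})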


def verify_processed_data(dataset_dict: DatasetDict):
    """Verify processed data quality.

    Args:
        dataset_dict: Processed dataset dictionary
    """
    # Loose sanity check: this pattern stops at the first closing brace, so it
    # does not validate nested braces exactly, but it is enough to confirm
    # that a boxed answer is present.
    pattern = re.compile(r"\\boxed\{[^}]+\}")
    found = None
    checked = 0
    nonempty = 0
    total = 0

    for split_name in ("train", "valid", "test"):
        if split_name in dataset_dict:
            ds = dataset_dict[split_name]
            total += len(ds)
            # Spot-check at most 100 examples per split.
            limit = min(100, len(ds))
            for i in range(limit):
                ex = ds[i]
                sol = ex.get("solution", "") or ""
                if pattern.search(sol):
                    nonempty += 1
                    if found is None:
                        found = ex
                checked += 1

    if found is not None:
        logging.info(f"Example after preprocessing (verified boxed): {found}")
    logging.info(f"Boxed solution stats: checked={checked}, nonempty_boxed={nonempty}, total={total}")


def load_config_from_yaml(config_path: str) -> Dict:
    """Load configuration from yaml file.

    Args:
        config_path: Path to yaml config file

    Returns:
        Dictionary with configuration parameters
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    dataset_config = config.get('datasets', {}).get('mm_math', {})

    # Prefer the sft block; fall back to grpo when sft is absent.
    sft_config = dataset_config.get('sft', {})
    grpo_config = dataset_config.get('grpo', {})
    mode_config = sft_config if sft_config else grpo_config

    def _get(key, default=None):
        # Explicit None checks so that falsy-but-valid values (e.g., a ratio
        # of 0.0) are not silently replaced by the fallback.
        value = mode_config.get(key)
        if value is None:
            value = dataset_config.get(key, default)
        return value

    return {
        'cache_path': _get('cache_path'),
        'val_ratio': _get('val_ratio', 0.05),
        'test_ratio': _get('test_ratio', 0.05),
        'image_root': _get('image_root'),
    }
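
# Expected YAML layout, inferred from the keys read above (illustrative; only
# the keys actually consumed by this script are shown):
#
# datasets:
#   mm_math:
#     sft:
#       cache_path: .cache/mm_math
#       val_ratio: 0.05
#       test_ratio: 0.05
#       image_root: dataset/mm_math/images/MM_Math
#     grpo:
#       ...   # same keys, used when no sft block is present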


def main():
    parser = argparse.ArgumentParser(description="Preprocess MM Math dataset")

    parser.add_argument("--config", type=str, default=None,
                        help="Path to yaml config file (e.g., configs/latent_memory/mm_math.yaml)")

    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory for preprocessed data (overrides config)")
    parser.add_argument("--cache_dir", type=str, default=None,
                        help="Cache directory for downloaded raw data (overrides config)")
    parser.add_argument("--image_root", type=str, default=None,
                        help="Directory for image files (overrides config)")
    parser.add_argument("--val_ratio", type=float, default=None,
                        help="Validation set ratio (overrides config)")
    parser.add_argument("--test_ratio", type=float, default=None,
                        help="Test set ratio (overrides config)")

    parser.add_argument("--batch_size", type=int, default=512,
                        help="Batch size for preprocessing")
    parser.add_argument("--skip_download", action="store_true",
                        help="Skip downloading data and images")

    args = parser.parse_args()

    # Resolution order: CLI argument > yaml config > built-in default.
    if args.config:
        if not os.path.exists(args.config):
            raise FileNotFoundError(f"Config file not found: {args.config}")

        logging.info(f"Loading configuration from {args.config}")
        yaml_config = load_config_from_yaml(args.config)

        output_dir = args.output_dir or "data/mm_math"
        cache_dir = args.cache_dir or yaml_config.get('cache_path') or ".cache/mm_math"
        image_root = args.image_root or yaml_config.get('image_root') or "dataset/mm_math/images/MM_Math"
        val_ratio = args.val_ratio if args.val_ratio is not None else yaml_config.get('val_ratio', 0.05)
        test_ratio = args.test_ratio if args.test_ratio is not None else yaml_config.get('test_ratio', 0.05)
    else:
        output_dir = args.output_dir or "data/mm_math"
        cache_dir = args.cache_dir or ".cache/mm_math"
        image_root = args.image_root or "dataset/mm_math/images/MM_Math"
        val_ratio = args.val_ratio if args.val_ratio is not None else 0.05
        test_ratio = args.test_ratio if args.test_ratio is not None else 0.05

    logging.info("=" * 80)
    logging.info("MM Math Dataset Preprocessing")
    logging.info("=" * 80)
    if args.config:
        logging.info(f"Config file: {args.config}")
    logging.info(f"Output directory: {output_dir}")
    logging.info(f"Cache directory: {cache_dir}")
    logging.info(f"Image root: {image_root}")
    logging.info(f"Val ratio: {val_ratio}, Test ratio: {test_ratio}")
    logging.info(f"Batch size: {args.batch_size}")
    logging.info("=" * 80)

    jsonl_path = os.path.join(cache_dir, "mm_math.jsonl")

    if args.skip_download:
        logging.info("Skipping download as --skip_download flag is set")
        if not os.path.exists(jsonl_path):
            raise FileNotFoundError(
                f"JSONL file not found at {jsonl_path}. "
                f"Remove --skip_download flag to download."
            )
    else:
        jsonl_path = download_data(cache_dir)
        download_images(image_root)

    dataset_dict = split_dataset(jsonl_path, val_ratio, test_ratio)
    processed_dict = preprocess_dataset(dataset_dict, image_root, args.batch_size)
    verify_processed_data(processed_dict)
    save_processed_data(processed_dict, output_dir)

    logging.info("=" * 80)
    logging.info("Preprocessing complete!")
    logging.info("=" * 80)


if __name__ == "__main__":
    main()