# rae-training / src / rae_data_formatter.py
# TrueV1sion123's picture
# Upload src/rae_data_formatter.py with huggingface_hub
# 3143539 verified
"""
RAE Data Formatter
═══════════════════════════════════════════════════════════════
Converts existing datasets into RAE-structured format.
Supports converting:
1. Standard Q&A datasets → RAE-structured chat
2. Chain-of-thought datasets → RAE phases (mapping reasoning steps to phases)
3. Code datasets → RAE-structured code reasoning
4. Custom formats via pluggable formatters
═══════════════════════════════════════════════════════════════
"""
import json
import re
from pathlib import Path
from typing import Callable, Optional
from rae_tokenizer_utils import PHASE_TAGS, validate_rae_response
# ── System Prompts by Domain ──────────────────────────────────
# Domain name -> system prompt. Every prompt is assembled from fragments joined
# by one space: a role sentence, a task lead-in, one fragment per RAE phase
# (SATURATION, ABSTRACTION, DESCENT, INTEGRATION) and the tag reminder.
SYSTEM_PROMPTS = {
    # Also serves as the fallback when a formatter is given an unknown domain.
    "general": " ".join((
        "You are an RAE-trained cognitive reasoner.",
        "For every problem, work through all four phases:",
        "SATURATION (explore without judgment),",
        "ABSTRACTION (extract minimal structure),",
        "DESCENT (concrete implementation),",
        "INTEGRATION (meta-learning).",
        "Use XML phase tags.",
    )),
    "code": " ".join((
        "You are an RAE-trained software engineer.",
        "For every coding task, work through:",
        "SATURATION (understand requirements, edge cases, constraints),",
        "ABSTRACTION (identify core algorithm/pattern),",
        "DESCENT (implement and test),",
        "INTEGRATION (what was learned, what generalizes).",
        "Use XML phase tags.",
    )),
    "analysis": " ".join((
        "You are an RAE-trained strategic analyst.",
        "For every analysis, work through:",
        "SATURATION (gather all signals, flag anomalies),",
        "ABSTRACTION (identify root mechanism),",
        "DESCENT (specific predictions and recommendations),",
        "INTEGRATION (confidence assessment, what would change the conclusion).",
        "Use XML phase tags.",
    )),
    "reasoning": " ".join((
        "You are an RAE-trained reasoner.",
        "For every problem, work through:",
        "SATURATION (map the full problem space without premature conclusions),",
        "ABSTRACTION (what's the underlying structure?),",
        "DESCENT (test implications concretely),",
        "INTEGRATION (update beliefs, identify next questions).",
        "Use XML phase tags.",
    )),
}
# A sentence ends at whitespace that follows terminal punctuation, or at a line
# break. A "." with no whitespace after it (as in "3.5") is not a boundary.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+|\n+")


def _split_sentences(text: str) -> list:
    """Split *text* into sentences, keeping decimals and end punctuation intact."""
    sentences = []
    for piece in _SENTENCE_BOUNDARY.split(text):
        piece = piece.strip()
        # Drop fragments that hold nothing but punctuation (e.g. a stray "...").
        if not piece.strip(".!?"):
            continue
        # Line-break-delimited steps may lack a terminator; add one so the
        # joined phase text reads as complete sentences.
        if piece[-1] not in ".!?":
            piece += "."
        sentences.append(piece)
    return sentences


def cot_to_rae(
    question: str,
    chain_of_thought: str,
    answer: str,
    domain: str = "general",
) -> Optional[dict]:
    """
    Convert a chain-of-thought example to RAE structure.

    Heuristic mapping (by sentence count):
    - First ~30% of CoT → Saturation (exploration/observation)
    - Next ~20% → Abstraction (key insight identification)
    - Next ~30% → Descent (working through specifics)
    - Final ~20% + answer → Integration (conclusion + meta-learning)

    Args:
        question: The prompt placed in the user message.
        chain_of_thought: Free-text reasoning to distribute over the phases.
        answer: Final answer; appended to the integration phase when non-empty.
        domain: Key into SYSTEM_PROMPTS; unknown keys fall back to "general".

    Returns:
        A dict with "messages" (system/user/assistant) and "metadata", or
        None when the reasoning has fewer than four sentences and so cannot
        fill all four phases.
    """
    cot_sentences = _split_sentences(chain_of_thought or "")
    total = len(cot_sentences)
    if total < 4:
        return None  # Too short to meaningfully decompose

    # Integer arithmetic keeps the cut points exact (no float rounding); with
    # total >= 4 every one of the four slices is non-empty.
    sat_end = total * 3 // 10
    abs_end = total // 2
    desc_end = total * 8 // 10

    saturation = " ".join(cot_sentences[:sat_end])
    abstraction = " ".join(cot_sentences[sat_end:abs_end])
    descent = " ".join(cot_sentences[abs_end:desc_end])
    integration = " ".join(cot_sentences[desc_end:])

    # Only state a final answer when there is one; an empty answer would
    # otherwise leave a dangling "Final answer: " line in the training text.
    answer_text = "" if answer is None else str(answer).strip()
    if answer_text:
        integration += f"\n\nFinal answer: {answer_text}"

    system = SYSTEM_PROMPTS.get(domain, SYSTEM_PROMPTS["general"])
    rae_response = (
        f"<SATURATION>\n{saturation}\n</SATURATION>\n\n"
        f"<ABSTRACTION>\n{abstraction}\n</ABSTRACTION>\n\n"
        f"<DESCENT>\n{descent}\n</DESCENT>\n\n"
        f"<INTEGRATION>\n{integration}\n</INTEGRATION>"
    )
    return {
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": question},
            {"role": "assistant", "content": rae_response},
        ],
        "metadata": {
            "domain": domain,
            "source_format": "cot",
            "rae_version": "1.0",
        },
    }
def qa_to_rae(
    question: str,
    answer: str,
    domain: str = "general",
    explanation: str = "",
) -> dict:
    """
    Wrap a plain question/answer pair in a minimal RAE scaffold.

    There is no reasoning chain to distribute, so the saturation phase
    restates the question (plus ``explanation`` when given, otherwise a
    stock sentence), the abstraction and integration phases are fixed
    boilerplate, and the descent phase carries ``answer`` verbatim.

    Args:
        question: Text of the user message; also echoed in the saturation phase.
        answer: Text placed inside the DESCENT tags.
        domain: Key into SYSTEM_PROMPTS; unknown keys fall back to "general".
        explanation: Optional context for the saturation phase.

    Returns:
        A dict with "messages" (system/user/assistant) and "metadata".
    """
    system = SYSTEM_PROMPTS.get(domain, SYSTEM_PROMPTS["general"])

    key_elements = explanation if explanation else "Let me explore the problem space."
    # (tag, body) pairs in emission order; each is rendered as <TAG>\nbody\n</TAG>.
    phase_bodies = (
        (
            "SATURATION",
            f"The question asks: {question}\nKey elements to consider: {key_elements}",
        ),
        (
            "ABSTRACTION",
            "The core structure of this problem is about identifying the right approach.",
        ),
        ("DESCENT", answer),
        (
            "INTEGRATION",
            "This reinforces the principle that careful problem decomposition "
            "leads to clearer solutions.",
        ),
    )
    rae_response = "\n\n".join(
        f"<{tag}>\n{body}\n</{tag}>" for tag, body in phase_bodies
    )

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": question},
        {"role": "assistant", "content": rae_response},
    ]
    metadata = {
        "domain": domain,
        "source_format": "qa",
        "rae_version": "1.0",
    }
    return {"messages": messages, "metadata": metadata}
def convert_hf_dataset(
    dataset_name: str,
    formatter: Callable,
    output_path: str,
    max_examples: int = 1000,
    train_split: str = "train",
    config_name: Optional[str] = None,
):
    """
    Convert a HuggingFace dataset to RAE format and write it as JSONL.

    Args:
        dataset_name: HF dataset identifier (e.g., "gsm8k")
        formatter: Function that converts a single example; a falsy return
            value means "skip this example"
        output_path: Where to write the JSONL output (parent dirs are created)
        max_examples: Maximum examples to convert
        train_split: Which split to use
        config_name: Optional dataset configuration (e.g., "main" for gsm8k);
            None lets the datasets library pick its default

    Returns:
        Number of examples written to ``output_path``.
    """
    # Imported lazily so the pure formatting helpers work without `datasets`.
    from datasets import load_dataset

    print(f"Loading {dataset_name}...")
    dataset = load_dataset(dataset_name, name=config_name, split=train_split)

    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    converted = 0
    skipped = 0
    # Explicit encoding keeps the output independent of the platform default.
    with open(output, "w", encoding="utf-8") as f:
        for example in dataset:
            if converted >= max_examples:
                break
            result = formatter(example)
            if not result:
                skipped += 1
                continue
            # The validator result is read as a mapping with "is_valid" and
            # "phases_found"; an example is kept when it is fully valid or
            # has at least three phases.
            validation = validate_rae_response(result["messages"][-1]["content"])
            if validation["is_valid"] or len(validation["phases_found"]) >= 3:
                f.write(json.dumps(result) + "\n")
                converted += 1
            else:
                skipped += 1

    print(f"Converted {converted} examples ({skipped} skipped) → {output}")
    return converted


# ── Pre-built Formatters for Popular Datasets ─────────────────
def format_gsm8k(example: dict) -> Optional[dict]:
    """Format GSM8K math reasoning to RAE.

    Returns whatever ``cot_to_rae`` returns for the row, which is None when
    the reasoning is too short to split into phases.
    """
    question = example.get("question", "")
    answer_text = example.get("answer", "")
    # GSM8K format: reasoning steps separated by \n, final answer after ####
    parts = answer_text.split("####")
    if len(parts) > 1:
        reasoning = parts[0].strip()
        final_answer = parts[1].strip()
    else:
        # No marker: treat the whole field as reasoning with no final answer.
        reasoning = answer_text
        final_answer = ""
    return cot_to_rae(question, reasoning, final_answer, domain="reasoning")


def format_code_alpaca(example: dict) -> Optional[dict]:
    """Format Code Alpaca to RAE.

    Returns None for rows without an instruction or an output, so the
    converter counts them as skipped instead of emitting an empty scaffold.
    """
    instruction = example.get("instruction") or ""
    output = example.get("output") or ""
    if not instruction.strip() or not output.strip():
        return None
    # Assumes the Alpaca-style schema where an optional "input" field holds
    # the data the instruction refers to — confirm against the dataset card.
    # Without it the question would be incomplete for such rows.
    extra_input = example.get("input") or ""
    question = f"{instruction}\n\n{extra_input}" if extra_input.strip() else instruction
    return qa_to_rae(question, output, domain="code")


def format_openassistant(example: dict) -> Optional[dict]:
    """Format OpenAssistant conversations to RAE.

    Returns None when the row has no "text"; otherwise the whole text is
    used as the answer to a fixed, generic question.
    """
    text = example.get("text", "")
    if not text:
        return None
    # Simple: wrap the whole response in RAE structure
    return qa_to_rae(
        "Respond helpfully to the following conversation.",
        text,
        domain="general",
    )


# ── Available Formatters Registry ─────────────────────────────
# CLI name -> (HF dataset id, dataset config or None, formatter function).
FORMATTERS = {
    "gsm8k": ("gsm8k", "main", format_gsm8k),
    "code_alpaca": ("sahil2801/CodeAlpaca-20k", None, format_code_alpaca),
    "openassistant": ("timdettmers/openassistant-guanaco", None, format_openassistant),
}


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert HF datasets to RAE format")
    parser.add_argument("--dataset", type=str, required=True, choices=list(FORMATTERS))
    parser.add_argument("--output", type=str, default="data/rae_training_data/converted.jsonl")
    parser.add_argument("--max_examples", type=int, default=500)
    args = parser.parse_args()

    dataset_id, config, formatter = FORMATTERS[args.dataset]
    convert_hf_dataset(
        dataset_name=dataset_id,
        formatter=formatter,
        output_path=args.output,
        max_examples=args.max_examples,
        # Forward the registry's config; gsm8k is registered with "main".
        config_name=config,
    )