# rae-training / src/rae_tokenizer_utils.py
# Uploaded by TrueV1sion123 with huggingface_hub (commit 9030cc5, verified).
"""
RAE Tokenizer Utilities
═══════════════════════════════════════════════════════════════
Phase-aware tokenization for RAE training data.
Handles the special structure of RAE responses where XML-style
phase tags delineate cognitive phases. Ensures proper tokenization
of phase boundaries and provides utilities for phase-level analysis.
═══════════════════════════════════════════════════════════════
"""
from typing import Optional
import re
# Phase name -> (opening tag, closing tag), listed in canonical phase order.
PHASE_TAGS = {
    "saturation": ("<SATURATION>", "</SATURATION>"),
    "abstraction": ("<ABSTRACTION>", "</ABSTRACTION>"),
    "descent": ("<DESCENT>", "</DESCENT>"),
    "integration": ("<INTEGRATION>", "</INTEGRATION>"),
}

# Flat list of every phase tag: each phase's opening tag immediately followed
# by its closing tag, in PHASE_TAGS order. Built with a comprehension so no
# loop variables leak into (and get exported from) the module namespace.
ALL_TAGS = [tag for tag_pair in PHASE_TAGS.values() for tag in tag_pair]
def add_rae_tokens(tokenizer):
    """
    Register every RAE phase tag as an additional special token on ``tokenizer``.

    Registering the tags keeps each phase boundary as a single token instead
    of letting it be split into subwords, which makes locating phase
    boundaries during loss computation far more reliable.

    Args:
        tokenizer: object exposing ``add_special_tokens`` (presumably a
            Hugging Face tokenizer — confirm against callers).

    Returns:
        A ``(tokenizer, num_added)`` tuple: the same tokenizer object that was
        passed in, and the count reported by ``add_special_tokens``.

    NOTE(review): this function does not touch any model; when ``num_added``
    is positive the caller presumably has to resize the model's token
    embeddings itself.
    """
    num_added = tokenizer.add_special_tokens(
        {"additional_special_tokens": ALL_TAGS}
    )
    # Only announce when the vocabulary actually grew.
    if num_added <= 0:
        return tokenizer, num_added
    print(f" Added {num_added} RAE phase tokens to tokenizer")
    return tokenizer, num_added
def extract_phases(text: str) -> dict[str, str]:
    """
    Pull the body of each RAE phase out of ``text``.

    For every phase in PHASE_TAGS, the first non-greedy span between its
    opening and closing tag is located (DOTALL, so a body may span lines) and
    its whitespace-stripped contents are returned. A phase whose tag pair is
    not found maps to the empty string.
    """
    def _phase_body(tags: tuple[str, str]) -> str:
        # Return the stripped text between one tag pair, or "" if absent.
        start_tag, end_tag = tags
        hit = re.search(
            re.escape(start_tag) + r"(.*?)" + re.escape(end_tag),
            text,
            re.DOTALL,
        )
        if hit is None:
            return ""
        return hit.group(1).strip()

    return {name: _phase_body(tags) for name, tags in PHASE_TAGS.items()}
def validate_rae_response(text: str) -> dict:
    """
    Validate that a response contains proper RAE structure.

    Returns a report dict with:
      - is_valid: True only when all four phases are present with non-empty
        content AND no warnings were raised
      - phases_found: names of phases with non-empty content (canonical order)
      - phases_missing: names of phases that are absent or empty
      - phase_lengths: whitespace-delimited word count for every phase
      - compression_ratio: abstraction word count / saturation word count,
        or None when the saturation phase is empty
      - warnings: list of human-readable issues (ordering, compression,
        very short / very long phases)
    """
    phases = extract_phases(text)
    found = [name for name, content in phases.items() if content]
    missing = [name for name, content in phases.items() if not content]
    warnings = []

    # Check phase ordering. extract_phases() always returns phases in the
    # canonical PHASE_TAGS order no matter where they sit in the text, so
    # comparing `found` against the expected order can never detect a
    # problem. Recover the real order from the position of each opening tag.
    expected_order = ["saturation", "abstraction", "descent", "integration"]
    present = [p for p in expected_order if p in found]
    # Every name in `present` had a matching tag pair, so find() is >= 0.
    textual_order = sorted(present, key=lambda p: text.find(PHASE_TAGS[p][0]))
    if textual_order != present:
        warnings.append("Phases appear out of order")

    # Check compression: the abstraction should be shorter than saturation.
    compression_ratio = None
    sat_len = len(phases.get("saturation", "").split())
    abs_len = len(phases.get("abstraction", "").split())
    if sat_len > 0:
        compression_ratio = abs_len / sat_len
        if compression_ratio > 1.0:
            warnings.append(f"Abstraction is LONGER than Saturation (ratio={compression_ratio:.2f})")

    # Check for degenerate (present but very short / very long) phases.
    for phase_name, content in phases.items():
        word_count = len(content.split())
        if content and word_count < 10:
            warnings.append(f"{phase_name} is very short ({word_count} words)")
        if content and word_count > 1000:
            warnings.append(f"{phase_name} is very long ({word_count} words)")

    return {
        "is_valid": len(found) == 4 and len(warnings) == 0,
        "phases_found": found,
        "phases_missing": missing,
        "phase_lengths": {name: len(content.split()) for name, content in phases.items()},
        "compression_ratio": compression_ratio,
        "warnings": warnings,
    }
def format_rae_chat(
    system_prompt: str,
    user_message: str,
    phases: dict[str, str],
    tokenizer=None,
) -> "str | list[dict[str, str]]":
    """
    Format RAE phases into a chat conversation for training.

    The assistant turn is built by wrapping each phase's content (missing
    phases become empty) in its opening/closing tags, in canonical order,
    with a blank line between phases.

    Args:
        system_prompt: content of the system message.
        user_message: content of the user message.
        phases: phase name -> phase content.
        tokenizer: optional object exposing ``apply_chat_template``
            (presumably a Hugging Face tokenizer).

    Returns:
        If ``tokenizer`` is given, the string produced by
        ``tokenizer.apply_chat_template(..., tokenize=False,
        add_generation_prompt=False)``. Otherwise the raw list of
        ``{"role": ..., "content": ...}`` message dicts.
    """
    phase_blocks = []
    for phase_name in ["saturation", "abstraction", "descent", "integration"]:
        open_tag, close_tag = PHASE_TAGS[phase_name]
        content = phases.get(phase_name, "")
        phase_blocks.append(f"{open_tag}\n{content}\n{close_tag}")
    # Blank line between phases, no trailing whitespace after the last one.
    assistant_content = "\n\n".join(phase_blocks)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_content},
    ]
    # Explicit None check: don't depend on the tokenizer object's truthiness.
    if tokenizer is not None:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return messages