# Source: dei-model/utils/validate_special_tokens.py
# Origin: commit da76488 ("Add utils directory") by renpas22
"""
Validation script to test special token integration.
Run this to verify that special tokens are correctly integrated into the framework.
"""
import sys
from pathlib import Path
# Add parent directory to path
sys.path.append(str(Path(__file__).parent.parent))
def test_special_tokens_defined():
    """Verify the special-token and description lists exist and have matching sizes."""
    print("=" * 60)
    print("TEST 1: Special tokens defined")
    print("=" * 60)
    from src.reasoning.step_data import SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS

    # Both lists must be complete and parallel (15 entries each).
    n_tokens = len(SPECIAL_TOKENS)
    n_descriptions = len(SPECIAL_TOKEN_DESCRIPTIONS)
    assert n_tokens == 15, f"Expected 15 tokens, got {n_tokens}"
    assert n_descriptions == 15, f"Expected 15 descriptions, got {n_descriptions}"
    print(f"✓ {n_tokens} special tokens defined")
    print(f"✓ {n_descriptions} descriptions defined")
    print("\nSpecial tokens:")
    for tok, description in zip(SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS):
        print(f" {tok:30s} -> {description}")
    print()
def test_step_formatting():
    """Verify that ReasoningStep.format_with_tokens emits every structural marker."""
    print("=" * 60)
    print("TEST 2: ReasoningStep formatting")
    print("=" * 60)
    from src.reasoning.step_data import ReasoningStep, StepType

    step = ReasoningStep(
        step_id=0,
        step_type=StepType.PERCEPTION,
        description="I observe three red apples in the bowl",
        confidence=0.95,
        dependencies=[],
    )
    formatted = step.format_with_tokens()
    print("Original step:")
    print(f" ID: {step.step_id}")
    print(f" Type: {step.step_type.value}")
    print(f" Description: {step.description}")
    print(f" Confidence: {step.confidence}")
    print("\nFormatted with special tokens:")
    print(f" {formatted}")
    # Every structural marker (including the 'ки' step separator) must appear.
    required_markers = (
        "<|step_start|>",
        "<|step_end|>",
        "<|step_type|>",
        "<|description_start|>",
        "<|description_end|>",
        "<|confidence_start|>",
        "<|confidence_end|>",
        "ки",
    )
    for marker in required_markers:
        assert marker in formatted
    print("\n✓ All required tokens present in formatted step")
    print()
def test_chain_formatting():
    """Verify that ReasoningChain.format_with_tokens produces a valid chain envelope."""
    print("=" * 60)
    print("TEST 3: ReasoningChain formatting")
    print("=" * 60)
    from src.reasoning.step_data import ReasoningStep, ReasoningChain, StepType

    # Two-step chain where step 1 depends on step 0.
    chain_steps = [
        ReasoningStep(
            step_id=0,
            step_type=StepType.PERCEPTION,
            description="I see red apples in a bowl",
            confidence=0.95,
            dependencies=[],
        ),
        ReasoningStep(
            step_id=1,
            step_type=StepType.COUNTING,
            description="Counting: 1, 2, 3 apples",
            confidence=0.92,
            dependencies=[0],
        ),
    ]
    chain = ReasoningChain(
        chain_id="test_001",
        image_path="test.jpg",
        prompt="How many red apples are there?",
        steps=chain_steps,
        final_answer="There are 3 red apples",
        is_correct=True,
    )
    formatted = chain.format_with_tokens()
    print("Original chain:")
    print(f" Prompt: {chain.prompt}")
    print(f" Steps: {len(chain.steps)}")
    print(f" Answer: {chain.final_answer}")
    print("\nFormatted with special tokens:")
    print(formatted)
    # Envelope markers must be present exactly once each.
    for envelope_marker in (
        "<|reasoning_start|>",
        "<|reasoning_end|>",
        "<|answer_start|>",
        "<|answer_end|>",
    ):
        assert envelope_marker in formatted
    # Per-step markers must appear once per step.
    for per_step_marker in ("<|step_start|>", "<|step_end|>", "ки"):
        assert formatted.count(per_step_marker) == 2
    assert "<|depends_on|>0" in formatted  # Second step depends on first
    print("\n✓ Chain structure valid")
    print("✓ Found 2 steps with dependencies")
    print()
def test_dataset_integration():
    """Verify that StepDataset items expose the token-formatted fields."""
    print("=" * 60)
    print("TEST 4: StepDataset integration")
    print("=" * 60)
    from src.reasoning.step_data import ReasoningStep, ReasoningChain, StepType, StepDataset
    from transformers import AutoTokenizer

    # A real tokenizer is needed; GPT-2 serves as a lightweight stand-in.
    try:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use GPT2 as dummy
    except Exception as e:
        # Offline environments are acceptable — the API shape is what matters.
        print(f"⚠ Could not load tokenizer for testing: {e}")
        print(" This is OK - just verifying API")
        return

    single_step = ReasoningStep(
        step_id=0,
        step_type=StepType.PERCEPTION,
        description="Test step",
        confidence=0.9,
    )
    chain = ReasoningChain(
        chain_id="test_001",
        image_path="test.jpg",
        prompt="Test prompt",
        steps=[single_step],
        final_answer="Test answer",
    )
    dataset = StepDataset(
        reasoning_chains=[chain],
        tokenizer=tokenizer,
        max_steps=10,
    )

    item = dataset[0]
    print("Dataset item keys:")
    for key in item.keys():
        print(f" {key}")
    # The formatted (token-decorated) views must be part of every item.
    for expected_key in ('formatted_input_ids', 'formatted_attention_mask', 'formatted_text'):
        assert expected_key in item
    print("\n✓ Dataset returns formatted data")
    print(f"✓ Formatted text length: {len(item['formatted_text'])} chars")
    print(f" Preview: {item['formatted_text'][:100]}...")
    print()
def test_trainer_methods():
    """Verify StepLevelCoTTrainer exposes the special-token helper methods.

    Checks that the two private helpers used for token integration exist,
    then prints their signatures so a reviewer can eyeball the expected
    arguments.
    """
    print("=" * 60)
    print("TEST 5: Trainer special token methods")
    print("=" * 60)
    from src.reasoning.step_level_cot import StepLevelCoTTrainer
    import inspect

    # Fix: the original also built an unused list of public method names;
    # only the single-underscore helpers are actually checked here.
    private_methods = [
        m for m in dir(StepLevelCoTTrainer)
        if m.startswith('_') and not m.startswith('__')
    ]
    assert '_add_special_tokens' in private_methods
    assert '_initialize_new_token_embeddings' in private_methods
    print("✓ Trainer has _add_special_tokens method")
    print("✓ Trainer has _initialize_new_token_embeddings method")

    # Surface the call signatures for manual inspection.
    add_tokens_sig = inspect.signature(StepLevelCoTTrainer._add_special_tokens)
    init_embeddings_sig = inspect.signature(StepLevelCoTTrainer._initialize_new_token_embeddings)
    print(f"\n_add_special_tokens{add_tokens_sig}")
    print(f"_initialize_new_token_embeddings{init_embeddings_sig}")
    print()
def test_utility_scripts():
    """Verify the companion utility scripts exist beside this file and define the expected API."""
    print("=" * 60)
    print("TEST 6: Utility scripts")
    print("=" * 60)
    utils_dir = Path(__file__).parent
    add_tokens_script = utils_dir / "add_special_tokens.py"
    example_script = utils_dir / "special_token_usage_example.py"

    for script in (add_tokens_script, example_script):
        assert script.exists(), f"Missing {script}"
    print(f"✓ Found {add_tokens_script.name}")
    print(f"✓ Found {example_script.name}")

    # Cheap structural check: required definitions appear in the source text.
    add_tokens_source = add_tokens_script.read_text()
    assert "def add_special_tokens_to_model" in add_tokens_source
    assert "if __name__ == \"__main__\":" in add_tokens_source
    example_source = example_script.read_text()
    assert "class SpecialTokenParser" in example_source
    assert "class SpecialTokenGenerator" in example_source
    print("✓ Scripts have required functions/classes")
    print()
def run_all_tests():
    """Run every validation test, reporting failures without aborting the suite."""
    print("\n" + "=" * 60)
    print("SPECIAL TOKENS INTEGRATION VALIDATION")
    print("=" * 60 + "\n")
    tests = [
        ("Special tokens defined", test_special_tokens_defined),
        ("ReasoningStep formatting", test_step_formatting),
        ("ReasoningChain formatting", test_chain_formatting),
        ("StepDataset integration", test_dataset_integration),
        ("Trainer methods", test_trainer_methods),
        ("Utility scripts", test_utility_scripts),
    ]

    passed = 0
    failed = 0
    for label, run_test in tests:
        try:
            run_test()
        except Exception as e:
            # Report the failure and keep going so one broken test
            # doesn't hide results from the rest of the suite.
            print(f"\n❌ TEST FAILED: {label}")
            print(f" Error: {e}")
            import traceback
            traceback.print_exc()
            failed += 1
        else:
            passed += 1

    print("\n" + "=" * 60)
    print("VALIDATION SUMMARY")
    print("=" * 60)
    print(f"Passed: {passed}/{len(tests)}")
    print(f"Failed: {failed}/{len(tests)}")
    if failed == 0:
        print("\n✅ All tests passed! Special tokens integration is working correctly.")
    else:
        print(f"\n⚠ {failed} test(s) failed. Please review the errors above.")
    print()
# Entry point: run the full validation suite when executed as a script.
if __name__ == "__main__":
    run_all_tests()