| """ |
| Validation script to test special token integration. |
| |
| Run this to verify that special tokens are correctly integrated into the framework. |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| |
| sys.path.append(str(Path(__file__).parent.parent)) |
|
|
def test_special_tokens_defined():
    """Verify the special-token vocabulary and its descriptions are defined and aligned."""
    print("=" * 60)
    print("TEST 1: Special tokens defined")
    print("=" * 60)

    from src.reasoning.step_data import SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS

    # Both lists must contain exactly the 15 tokens the framework expects.
    assert len(SPECIAL_TOKENS) == 15, f"Expected 15 tokens, got {len(SPECIAL_TOKENS)}"
    assert len(SPECIAL_TOKEN_DESCRIPTIONS) == 15, f"Expected 15 descriptions, got {len(SPECIAL_TOKEN_DESCRIPTIONS)}"

    print(f"✓ {len(SPECIAL_TOKENS)} special tokens defined")
    print(f"✓ {len(SPECIAL_TOKEN_DESCRIPTIONS)} descriptions defined")
    print("\nSpecial tokens:")
    # Walk tokens and descriptions in lockstep and show the mapping.
    for pair in zip(SPECIAL_TOKENS, SPECIAL_TOKEN_DESCRIPTIONS):
        print(f"  {pair[0]:30s} -> {pair[1]}")
    print()
|
|
|
|
def test_step_formatting():
    """Verify a single ReasoningStep renders all required markers via format_with_tokens()."""
    print("=" * 60)
    print("TEST 2: ReasoningStep formatting")
    print("=" * 60)

    from src.reasoning.step_data import ReasoningStep, StepType

    step = ReasoningStep(
        step_id=0,
        step_type=StepType.PERCEPTION,
        description="I observe three red apples in the bowl",
        confidence=0.95,
        dependencies=[],
    )

    formatted = step.format_with_tokens()

    print("Original step:")
    print(f"  ID: {step.step_id}")
    print(f"  Type: {step.step_type.value}")
    print(f"  Description: {step.description}")
    print(f"  Confidence: {step.confidence}")

    print("\nFormatted with special tokens:")
    print(f"  {formatted}")

    # Every structural marker (plus the 'ки' step separator) must appear
    # in the formatted string.
    required_markers = (
        "<|step_start|>",
        "<|step_end|>",
        "<|step_type|>",
        "<|description_start|>",
        "<|description_end|>",
        "<|confidence_start|>",
        "<|confidence_end|>",
        "ки",
    )
    for marker in required_markers:
        assert marker in formatted

    print("\n✓ All required tokens present in formatted step")
    print()
|
|
|
|
def test_chain_formatting():
    """Verify a two-step ReasoningChain renders chain-level markers and per-step structure."""
    print("=" * 60)
    print("TEST 3: ReasoningChain formatting")
    print("=" * 60)

    from src.reasoning.step_data import ReasoningStep, ReasoningChain, StepType

    perception = ReasoningStep(
        step_id=0,
        step_type=StepType.PERCEPTION,
        description="I see red apples in a bowl",
        confidence=0.95,
        dependencies=[],
    )
    counting = ReasoningStep(
        step_id=1,
        step_type=StepType.COUNTING,
        description="Counting: 1, 2, 3 apples",
        confidence=0.92,
        dependencies=[0],  # depends on the perception step
    )

    chain = ReasoningChain(
        chain_id="test_001",
        image_path="test.jpg",
        prompt="How many red apples are there?",
        steps=[perception, counting],
        final_answer="There are 3 red apples",
        is_correct=True,
    )

    formatted = chain.format_with_tokens()

    print("Original chain:")
    print(f"  Prompt: {chain.prompt}")
    print(f"  Steps: {len(chain.steps)}")
    print(f"  Answer: {chain.final_answer}")

    print("\nFormatted with special tokens:")
    print(formatted)

    # Chain-level wrappers must each appear once.
    for wrapper in ("<|reasoning_start|>", "<|reasoning_end|>",
                    "<|answer_start|>", "<|answer_end|>"):
        assert wrapper in formatted

    # Per-step markers must appear once per step (2 steps here),
    # and the dependency of step 1 on step 0 must be encoded.
    assert formatted.count("<|step_start|>") == 2
    assert formatted.count("<|step_end|>") == 2
    assert formatted.count("ки") == 2
    assert "<|depends_on|>0" in formatted

    print("\n✓ Chain structure valid")
    print("✓ Found 2 steps with dependencies")
    print()
|
|
|
|
def test_dataset_integration():
    """Verify StepDataset yields items containing the formatted (tokenized) fields."""
    print("=" * 60)
    print("TEST 4: StepDataset integration")
    print("=" * 60)

    from src.reasoning.step_data import ReasoningStep, ReasoningChain, StepType, StepDataset
    from transformers import AutoTokenizer

    # Tokenizer download may fail offline; in that case the test degrades
    # to an API-presence check and exits early.
    try:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    except Exception as e:
        print(f"⚠ Could not load tokenizer for testing: {e}")
        print("  This is OK - just verifying API")
        return

    single_step = ReasoningStep(
        step_id=0,
        step_type=StepType.PERCEPTION,
        description="Test step",
        confidence=0.9,
    )

    chain = ReasoningChain(
        chain_id="test_001",
        image_path="test.jpg",
        prompt="Test prompt",
        steps=[single_step],
        final_answer="Test answer",
    )

    dataset = StepDataset(
        reasoning_chains=[chain],
        tokenizer=tokenizer,
        max_steps=10,
    )

    item = dataset[0]

    print("Dataset item keys:")
    for key in item:
        print(f"  {key}")

    # The dataset must surface the token-formatted representation
    # alongside the raw text.
    for field in ("formatted_input_ids", "formatted_attention_mask", "formatted_text"):
        assert field in item

    print("\n✓ Dataset returns formatted data")
    print(f"✓ Formatted text length: {len(item['formatted_text'])} chars")
    print(f"  Preview: {item['formatted_text'][:100]}...")
    print()
|
|
|
|
def test_trainer_methods():
    """Verify StepLevelCoTTrainer exposes the private special-token setup methods.

    Checks that `_add_special_tokens` and `_initialize_new_token_embeddings`
    exist on the class and prints their signatures for manual inspection.
    """
    print("=" * 60)
    print("TEST 5: Trainer special token methods")
    print("=" * 60)

    from src.reasoning.step_level_cot import StepLevelCoTTrainer
    import inspect

    # Collect single-underscore (non-dunder) attributes; the public-method
    # list from the original version was unused and has been removed.
    private_methods = [
        m for m in dir(StepLevelCoTTrainer)
        if m.startswith('_') and not m.startswith('__')
    ]

    assert '_add_special_tokens' in private_methods
    assert '_initialize_new_token_embeddings' in private_methods

    print("✓ Trainer has _add_special_tokens method")
    print("✓ Trainer has _initialize_new_token_embeddings method")

    # Show the signatures so a reviewer can confirm the expected parameters.
    add_tokens_sig = inspect.signature(StepLevelCoTTrainer._add_special_tokens)
    init_embeddings_sig = inspect.signature(StepLevelCoTTrainer._initialize_new_token_embeddings)

    print(f"\n_add_special_tokens{add_tokens_sig}")
    print(f"_initialize_new_token_embeddings{init_embeddings_sig}")
    print()
|
|
|
|
def test_utility_scripts():
    """Verify the companion utility scripts exist and contain their expected entry points.

    Looks for `add_special_tokens.py` and `special_token_usage_example.py`
    next to this file and greps their text for the required function/class
    definitions.
    """
    print("=" * 60)
    print("TEST 6: Utility scripts")
    print("=" * 60)

    utils_dir = Path(__file__).parent

    add_tokens_script = utils_dir / "add_special_tokens.py"
    example_script = utils_dir / "special_token_usage_example.py"

    assert add_tokens_script.exists(), f"Missing {add_tokens_script}"
    assert example_script.exists(), f"Missing {example_script}"

    print(f"✓ Found {add_tokens_script.name}")
    print(f"✓ Found {example_script.name}")

    # Cheap textual check that the scripts expose the documented API.
    # read_text with an explicit encoding replaces the open()/read() pairs
    # and avoids platform-dependent default encodings.
    content = add_tokens_script.read_text(encoding="utf-8")
    assert "def add_special_tokens_to_model" in content
    assert "if __name__ == \"__main__\":" in content

    content = example_script.read_text(encoding="utf-8")
    assert "class SpecialTokenParser" in content
    assert "class SpecialTokenGenerator" in content

    print("✓ Scripts have required functions/classes")
    print()
|
|
|
|
def run_all_tests():
    """Run every validation test, reporting per-test failures and a final summary.

    Each test runs in isolation: an exception marks that test as failed
    (with its traceback printed) but does not stop the remaining tests.
    """
    # Hoisted out of the except clause: importing per failure inside the
    # loop re-executed the import machinery on every failing test.
    import traceback

    print("\n" + "=" * 60)
    print("SPECIAL TOKENS INTEGRATION VALIDATION")
    print("=" * 60 + "\n")

    tests = [
        ("Special tokens defined", test_special_tokens_defined),
        ("ReasoningStep formatting", test_step_formatting),
        ("ReasoningChain formatting", test_chain_formatting),
        ("StepDataset integration", test_dataset_integration),
        ("Trainer methods", test_trainer_methods),
        ("Utility scripts", test_utility_scripts),
    ]

    passed = 0
    failed = 0

    for test_name, test_func in tests:
        try:
            test_func()
            passed += 1
        except Exception as e:
            print(f"\n❌ TEST FAILED: {test_name}")
            print(f"   Error: {e}")
            traceback.print_exc()
            failed += 1

    print("\n" + "=" * 60)
    print("VALIDATION SUMMARY")
    print("=" * 60)
    print(f"Passed: {passed}/{len(tests)}")
    print(f"Failed: {failed}/{len(tests)}")

    if failed == 0:
        print("\n✅ All tests passed! Special tokens integration is working correctly.")
    else:
        print(f"\n⚠ {failed} test(s) failed. Please review the errors above.")

    print()
|
|
|
|
# Entry point: run the full validation suite when executed as a script.
if __name__ == "__main__":
    run_all_tests()
|
|