Spaces:
Running
Running
File size: 3,881 Bytes
40fd629 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
#!/usr/bin/env python3
"""
Test script to verify trainer selection logic
"""
import sys
import os
from pathlib import Path
# Make the repository root and its ``config/`` directory importable by name.
# Both entries are prepended in one step, leaving ``config/`` first on the
# search path and the repository root right after it.
project_root = Path(__file__).parent.parent
sys.path[:0] = [str(project_root / "config"), str(project_root)]
def test_config_trainer_type():
    """Check that each config class reports the expected ``trainer_type``.

    A default ``SmolLM3Config`` (from ``train_smollm3``) must report
    ``"sft"`` and a default ``SmolLM3DPOConfig`` (from ``train_smollm3_dpo``)
    must report ``"dpo"``. Both modules are imported by name, so they must be
    reachable on ``sys.path``.

    Returns:
        True when both checks pass.

    Raises:
        ImportError: if either config module cannot be imported.
        AssertionError: if a config reports an unexpected ``trainer_type``.
    """
    print("Testing config trainer_type...")

    # Unlike the other checks, this one does not catch exceptions itself;
    # failures propagate to the caller (main() reports them).

    # Test base config
    from train_smollm3 import SmolLM3Config
    base_config = SmolLM3Config()
    assert base_config.trainer_type == "sft", f"Base config should have trainer_type='sft', got {base_config.trainer_type}"
    print("✅ Base config trainer_type: sft")

    # Test DPO config
    from train_smollm3_dpo import SmolLM3DPOConfig
    dpo_config = SmolLM3DPOConfig()
    assert dpo_config.trainer_type == "dpo", f"DPO config should have trainer_type='dpo', got {dpo_config.trainer_type}"
    print("✅ DPO config trainer_type: dpo")

    return True
def test_trainer_classes_exist():
    """Check that the ``trainer`` module defines both trainer classes.

    Prepends ``<project_root>/src`` to ``sys.path``, imports ``trainer`` and
    looks for the ``SmolLM3Trainer`` and ``SmolLM3DPOTrainer`` attributes.

    Returns:
        True when both classes are present. False if the import fails or
        a class is missing; the error is printed rather than raised.
    """
    print("Testing trainer class existence...")

    try:
        # The trainer module lives under src/, which is not otherwise put on
        # sys.path. NOTE(review): every call prepends another copy of it.
        sys.path.insert(0, str(project_root / "src"))

        import trainer
        print("✅ Trainer module imported successfully")

        # Both the SFT and the DPO trainer must be exposed by the module.
        for class_name in ("SmolLM3Trainer", "SmolLM3DPOTrainer"):
            assert hasattr(trainer, class_name), f"{class_name} class not found"
        print("✅ Both trainer classes exist")

        return True
    except Exception as e:
        print(f"❌ Failed to check trainer classes: {e}")
        return False
def test_config_inheritance():
    """Check that ``SmolLM3DPOConfig`` carries every field of ``SmolLM3Config``.

    Compares the instance attribute names of default-constructed configs. The
    DPO config may add fields of its own but must not drop any base field.
    It then re-checks that ``trainer_type`` differs as expected between the two.

    Returns:
        True when all checks pass. False if an import fails or any check
        fails; the error is printed rather than raised.
    """
    print("Testing config inheritance...")

    try:
        from train_smollm3 import SmolLM3Config
        from train_smollm3_dpo import SmolLM3DPOConfig

        base_config = SmolLM3Config()
        dpo_config = SmolLM3DPOConfig()

        # DPO config should have all base fields plus DPO-specific ones.
        # Reporting the missing names makes a failure actionable.
        # NOTE(review): relies on instances having a __dict__; a config using
        # __slots__ would make vars() raise TypeError (reported as failure).
        missing = set(vars(base_config)) - set(vars(dpo_config))
        assert not missing, f"DPO config missing base config fields: {sorted(missing)}"
        print("✅ DPO config properly inherits from base config")

        # The DPO config must override trainer_type; the base keeps "sft".
        assert dpo_config.trainer_type == "dpo", "DPO config should have trainer_type='dpo'"
        assert base_config.trainer_type == "sft", "Base config should have trainer_type='sft'"
        print("✅ Trainer type inheritance works correctly")

        return True
    except Exception as e:
        print(f"❌ Failed to test config inheritance: {e}")
        return False
def main():
    """Run every check in order and print a pass/fail summary.

    A check passes when it returns a truthy value. A falsy return or a raised
    exception counts as a failure. Exceptions are printed and do not stop the
    remaining checks.

    Returns:
        0 if every check passed, 1 otherwise (used as the process exit code).
    """
    print("🧪 Testing Trainer Selection Implementation")
    print("=" * 50)

    tests = [
        test_config_trainer_type,
        test_trainer_classes_exist,
        test_config_inheritance,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        # Some checks report failure by returning False, others by raising;
        # both paths are counted as a failure here.
        try:
            if test():
                passed += 1
            else:
                print(f"❌ Test {test.__name__} failed")
        except Exception as e:
            print(f"❌ Test {test.__name__} failed with exception: {e}")

    print("=" * 50)
    print(f"Tests passed: {passed}/{total}")

    if passed == total:
        print("🎉 All tests passed!")
        return 0

    print("❌ Some tests failed!")
    return 1
if __name__ == "__main__":
    # Use sys.exit rather than the site-provided exit() helper. That helper
    # is meant for the interactive shell and is absent when Python runs with -S.
    sys.exit(main())