Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Test script to verify trainer selection logic | |
""" | |
import sys | |
import os | |
from pathlib import Path | |
# Add project root to path | |
project_root = Path(__file__).parent.parent | |
sys.path.insert(0, str(project_root)) | |
sys.path.insert(0, str(project_root / "config")) | |
def test_config_trainer_type(): | |
"""Test that config files have the correct trainer_type""" | |
print("Testing config trainer_type...") | |
# Test base config | |
from train_smollm3 import SmolLM3Config | |
base_config = SmolLM3Config() | |
assert base_config.trainer_type == "sft", f"Base config should have trainer_type='sft', got {base_config.trainer_type}" | |
print("β Base config trainer_type: sft") | |
# Test DPO config | |
from train_smollm3_dpo import SmolLM3DPOConfig | |
dpo_config = SmolLM3DPOConfig() | |
assert dpo_config.trainer_type == "dpo", f"DPO config should have trainer_type='dpo', got {dpo_config.trainer_type}" | |
print("β DPO config trainer_type: dpo") | |
return True | |
def test_trainer_classes_exist(): | |
"""Test that trainer classes exist in the trainer module""" | |
print("Testing trainer class existence...") | |
try: | |
# Add src to path | |
sys.path.insert(0, str(project_root / "src")) | |
# Import trainer module | |
import trainer | |
print("β Trainer module imported successfully") | |
# Check if classes exist | |
assert hasattr(trainer, 'SmolLM3Trainer'), "SmolLM3Trainer class not found" | |
assert hasattr(trainer, 'SmolLM3DPOTrainer'), "SmolLM3DPOTrainer class not found" | |
print("β Both trainer classes exist") | |
return True | |
except Exception as e: | |
print(f"β Failed to check trainer classes: {e}") | |
return False | |
def test_config_inheritance(): | |
"""Test that DPO config properly inherits from base config""" | |
print("Testing config inheritance...") | |
try: | |
from train_smollm3 import SmolLM3Config | |
from train_smollm3_dpo import SmolLM3DPOConfig | |
# Test that DPO config inherits from base config | |
base_config = SmolLM3Config() | |
dpo_config = SmolLM3DPOConfig() | |
# Check that DPO config has all base config fields | |
base_fields = set(base_config.__dict__.keys()) | |
dpo_fields = set(dpo_config.__dict__.keys()) | |
# DPO config should have all base fields plus DPO-specific ones | |
assert base_fields.issubset(dpo_fields), "DPO config missing base config fields" | |
print("β DPO config properly inherits from base config") | |
# Check that trainer_type is overridden correctly | |
assert dpo_config.trainer_type == "dpo", "DPO config should have trainer_type='dpo'" | |
assert base_config.trainer_type == "sft", "Base config should have trainer_type='sft'" | |
print("β Trainer type inheritance works correctly") | |
return True | |
except Exception as e: | |
print(f"β Failed to test config inheritance: {e}") | |
return False | |
def main(): | |
"""Run all tests""" | |
print("π§ͺ Testing Trainer Selection Implementation") | |
print("=" * 50) | |
tests = [ | |
test_config_trainer_type, | |
test_trainer_classes_exist, | |
test_config_inheritance, | |
] | |
passed = 0 | |
total = len(tests) | |
for test in tests: | |
try: | |
if test(): | |
passed += 1 | |
else: | |
print(f"β Test {test.__name__} failed") | |
except Exception as e: | |
print(f"β Test {test.__name__} failed with exception: {e}") | |
print("=" * 50) | |
print(f"Tests passed: {passed}/{total}") | |
if passed == total: | |
print("π All tests passed!") | |
return 0 | |
else: | |
print("β Some tests failed!") | |
return 1 | |
if __name__ == "__main__": | |
exit(main()) |