|
|
| """
|
| Quick test script to verify CNN training functionality
|
| """
|
|
|
| import sys
|
| import os
|
| sys.path.append('.')
|
|
|
def test_cnn_import():
    """Check that ``CNNDeblurModel`` can be imported from ``modules.cnn_deblurring``.

    Returns:
        bool: True if the import succeeds, False if it raises any exception
        (the error is printed instead of being propagated).
    """
    print("🧪 Testing CNN module import...")
    try:
        # The name itself is unused: a successful import *is* the check.
        from modules.cnn_deblurring import CNNDeblurModel  # noqa: F401
    except Exception as e:  # any failure while importing counts as a failed check
        print(f"❌ CNN import failed: {e}")
        return False

    print("✅ CNN module imported successfully")
    return True
|
|
|
def test_model_creation():
    """Instantiate ``CNNDeblurModel`` and build its network.

    Imports the class, constructs it with no arguments, calls
    ``build_model()`` and prints the model's ``input_shape`` and whether its
    ``model`` attribute is set.

    Returns:
        bool: True if every step completes, False if any step raises (the
        error is printed instead of being propagated).
    """
    print("🧪 Testing model creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        model = CNNDeblurModel()
        model.build_model()

        # Kept inside the try: reading these attributes may itself fail.
        print("✅ Model created successfully")
        print(f" Model input shape: {model.input_shape}")
        print(f" Model built: {model.model is not None}")
        return True
    except Exception as e:
        print(f"❌ Model creation failed: {e}")
        return False
|
|
|
def test_user_images(dataset_path="data/training_dataset"):
    """Report the image files found in the training dataset directory.

    Args:
        dataset_path: Directory to scan. Defaults to ``data/training_dataset``
            (relative to the current working directory), the path this check
            has always used.

    Returns:
        bool: True if the directory exists (even when it holds no images),
        False if it is missing or scanning it raises an exception.
    """
    print("🧪 Testing user images detection...")
    # A tuple so str.endswith can test every extension in a single call.
    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
    try:
        if not os.path.isdir(dataset_path):
            print("⚠️ Training dataset directory not found")
            return False

        # os.listdir returns entries in arbitrary order; sort so the preview
        # below is stable, and skip sub-directories that merely look like
        # image names.
        user_images = sorted(
            entry
            for entry in os.listdir(dataset_path)
            if entry.lower().endswith(valid_extensions)
            and os.path.isfile(os.path.join(dataset_path, entry))
        )

        print(f"✅ Found {len(user_images)} user training images")
        for img in user_images[:5]:
            print(f" - {img}")
        if len(user_images) > 5:
            print(f" ... and {len(user_images) - 5} more")
        return True
    except Exception as e:
        print(f"❌ User images test failed: {e}")
        return False
|
|
|
def test_quick_dataset_creation():
    """Generate a 10-sample training dataset without saving it.

    Constructs a ``CNNDeblurModel`` and calls
    ``create_training_dataset(num_samples=10, save_dataset=False)``, then
    prints the ``shape`` of the two returned objects.

    Returns:
        bool: True if the dataset is created and both shapes can be printed,
        False if any step raises (the error is printed instead of being
        propagated).
    """
    print("🧪 Testing quick dataset creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        model = CNNDeblurModel()

        print(" Creating 10 sample dataset...")
        blurred, clean = model.create_training_dataset(
            num_samples=10, save_dataset=False
        )

        print("✅ Dataset created successfully")
        print(f" Blurred images shape: {blurred.shape}")
        print(f" Clean images shape: {clean.shape}")
        return True
    except Exception as e:
        print(f"❌ Dataset creation failed: {e}")
        return False
|
|
|
def main():
    """Run every check in order and print a pass/fail summary.

    Returns:
        int: 0 if all checks passed, 1 otherwise, so the value can be used
        directly as a process exit status.
    """
    # NOTE(review): the emoji below were restored from mis-encoded text; the
    # rocket, clipboard, chart and party glyphs are best guesses because their
    # distinguishing bytes were lost. Confirm the sidebar and button labels
    # against the Streamlit app.
    print("🚀 CNN Training Test Suite")
    print("=" * 40)

    tests = [
        ("CNN Import", test_cnn_import),
        ("Model Creation", test_model_creation),
        ("User Images", test_user_images),
        ("Dataset Creation", test_quick_dataset_creation),
    ]

    passed = 0
    total = len(tests)

    for name, test_func in tests:
        print(f"\n📋 {name}")
        print("-" * 20)
        if test_func():
            passed += 1
        print()

    print("=" * 40)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! Training should work correctly.")
        print("\n💡 Next steps:")
        print(" 1. Go to your Streamlit app: http://localhost:8503")
        print(" 2. Look for '🤖 CNN Model Management' in sidebar")
        print(" 3. Click '⚡ Quick Train' to start training")
        return 0

    print("❌ Some tests failed. Please check the errors above.")
    return 1
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to sys.exit so a non-zero result becomes the
    # process exit status; a None return still exits with status 0.
    sys.exit(main())