# Source: kolam-ai-generator/test_fixes.py
# Uploaded by Rishab7310 ("Upload 7 files", commit 7a4e326)
"""
Test script to verify all three problems are fixed.
"""
import torch
import numpy as np
import sys
import os
from pathlib import Path
def test_problem_1_notebook_imports():
    """Test Problem 1: Notebook import path fixes.

    Adds the project's ``models`` directory to ``sys.path`` and imports the
    three model modules by bare name (``gan_generator`` etc.), mirroring the
    notebook fix. That fix appears to be
    ``os.path.join(os.path.dirname(os.getcwd()), 'models')``, which presumably
    assumes the notebook runs from a direct sub-folder of the project root --
    confirm against the notebooks.

    Assumes this script sits in the project root next to ``models/``. The
    ``models.*`` imports elsewhere in this script rely on the same layout.

    Returns:
        True if all three modules import successfully, False otherwise (the
        reason is printed).
    """
    print("πŸ”§ Testing Problem 1: Notebook Import Paths...")
    # Resolve from this file rather than os.getcwd(): the cwd-based path only
    # lands on models/ when the script is launched from a sub-folder of the
    # project root, so the outcome would depend on the launch directory.
    models_dir = Path(__file__).resolve().parent / "models"
    if not models_dir.is_dir():
        print(f"❌ Problem 1 NOT FIXED: models directory not found at {models_dir}")
        return False
    models_path = str(models_dir)
    # Avoid growing sys.path with a duplicate entry on every call.
    if models_path not in sys.path:
        sys.path.append(models_path)
    try:
        # Imported only to prove they resolve; the names are intentionally unused.
        from gan_generator import KolamGenerator  # noqa: F401
        from gan_discriminator import KolamDiscriminator  # noqa: F401
        from cnn_feature_extractor import KolamFeatureExtractor  # noqa: F401
    except Exception as e:  # report any import-time failure, not just ImportError
        print(f"❌ Problem 1 NOT FIXED: {e}")
        return False
    print("βœ… Problem 1 FIXED: All imports work correctly!")
    return True
def test_problem_2_discriminator():
    """Test Problem 2: Discriminator spectral norm fix.

    Builds a ``KolamDiscriminator`` with ``use_spectral_norm=True`` for
    1-channel 64x64 inputs, switches it to eval mode and runs a single
    gradient-free forward pass on a random batch of two images.

    Returns:
        True when construction and the forward pass succeed, False when any
        exception is raised (the error message is printed).
    """
    print("\nπŸ”§ Testing Problem 2: Discriminator Spectral Norm...")
    try:
        from models.gan_discriminator import KolamDiscriminator

        disc = KolamDiscriminator(input_channels=1, image_size=64, use_spectral_norm=True)
        disc.eval()
        # Batch of 2, 1 channel, 64x64 -- presumably NCHW, matching image_size.
        batch = torch.randn(2, 1, 64, 64)
        with torch.no_grad():
            scores = disc(batch)
        # Kept inside the try so a missing .shape on the output counts as failure.
        print(f"βœ… Problem 2 FIXED: Discriminator works! Output shape: {scores.shape}")
        return True
    except Exception as err:
        print(f"❌ Problem 2 NOT FIXED: {err}")
        return False
def test_problem_3_jupyter():
    """Test Problem 3: Jupyter installation and access.

    Imports ``jupyter``, ``jupyterlab`` and ``notebook`` one by one. Every
    module that fails is collected, so the failure message lists all missing
    or broken components instead of only the first one.

    Returns:
        True if all three modules import, False otherwise (the failing
        modules and their errors are printed).
    """
    import importlib

    print("\nπŸ”§ Testing Problem 3: Jupyter Access...")
    failures = []
    for module_name in ("jupyter", "jupyterlab", "notebook"):
        try:
            importlib.import_module(module_name)
        except Exception as e:  # a broken install can raise more than ImportError
            failures.append(f"{module_name} ({e})")
    if failures:
        print(f"❌ Problem 3 NOT FIXED: {'; '.join(failures)}")
        return False
    print("βœ… Problem 3 FIXED: Jupyter is properly installed!")
    print(" - Jupyter Core: Available")
    print(" - JupyterLab: Available")
    print(" - Notebook: Available")
    print(" - Use: python -m jupyter lab")
    return True
def test_complete_workflow():
    """Test the complete workflow after fixes.

    End-to-end smoke test with these steps:
      1. Import every model and utility module.
      2. Build the generator, discriminator and feature extractor.
      3. Run one gradient-free forward pass through each on random data.
      4. Create a synthetic kolam and score it with ``KolamDesignMetrics``.

    Returns:
        True if every step succeeds, False if any exception is raised (the
        traceback is printed).
    """
    print("\nπŸ”§ Testing Complete Workflow...")
    try:
        from models.gan_generator import KolamGenerator
        from models.gan_discriminator import KolamDiscriminator
        from models.cnn_feature_extractor import KolamFeatureExtractor
        from utils.image_utils import create_synthetic_kolam
        from utils.metrics import KolamDesignMetrics

        # Test complete pipeline
        print("βœ… All imports successful")

        # Test model creation
        generator = KolamGenerator(noise_dim=100, feature_dim=128, output_channels=1, image_size=64)
        discriminator = KolamDiscriminator(input_channels=1, image_size=64)
        feature_extractor = KolamFeatureExtractor(input_channels=1, feature_dim=128)
        # Inference mode, as test_problem_2_discriminator does: for torch
        # modules this disables dropout and stops BatchNorm layers (if any)
        # from updating running statistics during these smoke passes.
        for model in (generator, discriminator, feature_extractor):
            model.eval()
        print("βœ… All models created successfully")

        # Test forward passes; noise width matches noise_dim=100 above.
        noise = torch.randn(2, 100)
        real_images = torch.randn(2, 1, 64, 64)
        with torch.no_grad():
            generated = generator(noise)
            features = feature_extractor(real_images)
            real_scores = discriminator(real_images)
            fake_scores = discriminator(generated)
        print("βœ… All forward passes successful")
        print(f" - Generated: {generated.shape}")
        print(f" - Features: {features.shape}")
        print(f" - Real scores: {real_scores.shape}")
        print(f" - Fake scores: {fake_scores.shape}")

        # Test utilities
        kolam = create_synthetic_kolam(size=(64, 64), complexity='medium')
        metrics = KolamDesignMetrics()
        quality = metrics.calculate_overall_quality(kolam)
        print("βœ… All utilities working")
        print(f" - Kolam created: {kolam.shape}")
        print(f" - Quality score: {quality['overall_quality']:.3f}")
        return True
    except Exception as e:
        print(f"❌ Complete workflow failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def main(tests=None):
    """Main test function: run every check and print a summary.

    Args:
        tests: Optional sequence of ``(name, callable)`` pairs. Each callable
            takes no arguments and returns a truthy value on success. When
            omitted, the four checks defined in this script are run.

    Returns:
        Process exit status: ``0`` when every test passed, ``1`` otherwise,
        so a caller can hand it to ``sys.exit`` and CI can detect failures.
    """
    print("πŸš€ TESTING ALL THREE PROBLEM FIXES")
    print("=" * 50)
    if tests is None:
        tests = [
            ("Problem 1: Notebook Imports", test_problem_1_notebook_imports),
            ("Problem 2: Discriminator", test_problem_2_discriminator),
            ("Problem 3: Jupyter Access", test_problem_3_jupyter),
            ("Complete Workflow", test_complete_workflow),
        ]
    passed = 0
    total = len(tests)
    for test_name, test_func in tests:
        if test_func():
            passed += 1
        else:
            print(f"❌ {test_name} FAILED")
    print(f"\nπŸ“Š RESULTS: {passed}/{total} tests passed")
    if passed == total:
        print("\nπŸŽ‰ ALL THREE PROBLEMS FIXED!")
        print("βœ… The Kolam AI Generator is now fully functional!")
        print("\nπŸš€ Ready to use:")
        print("1. python -m jupyter lab # Start Jupyter Lab")
        print("2. python run_demo.py # Run full demonstration")
        print("3. python scripts/train_cnn.py # Train models")
        return 0
    print(f"\n⚠️ {total - passed} problems still need fixing")
    return 1
if __name__ == "__main__":
main()