# Uploaded by ganeshkumar383 in "Upload 27 files (#2)" (commit ecc16d3, verified).
"""
Quick CNN Training Script
========================
Simple script to quickly train the CNN model for the Image Deblurring application.
"""
import os
import sys
import traceback
# Put the folder containing this script (not the working directory) on
# sys.path so the sibling ``modules`` package can be imported by main().
# It is appended, i.e. searched after every entry already on the path.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path.extend([current_dir])
# Training presets offered in the interactive menu, keyed by the menu choice.
# Each entry: (title, number of training samples, epochs, rough duration, summary).
# Both the printed menu and the selected configuration come from this table,
# so the two cannot drift apart.
_TRAINING_MODES = {
    "1": ("Quick Training (Recommended for testing)", 500, 10,
          "~10-15 minutes", "Good for initial testing"),
    "2": ("Standard Training", 1000, 20,
          "~20-30 minutes", "Balanced quality and time"),
    "3": ("Full Training", 2000, 30,
          "~45-60 minutes", "Best quality results"),
}


def _confirm_overwrite(model_path):
    """Decide whether training may proceed given a possibly existing model.

    Returns True when no file exists at ``model_path`` or the user chooses to
    overwrite it, and False when the user keeps the existing model or enters
    an invalid choice. EOFError/KeyboardInterrupt raised by ``input`` are not
    caught here.
    """
    if not os.path.exists(model_path):
        return True
    print("⚠️ A trained model already exists!")
    print(f" Location: {model_path}")
    choice = input(
        "\nDo you want to:\n"
        " (1) Keep existing model\n"
        " (2) Train new model (overwrites existing)\n"
        "\nChoice (1/2): "
    ).strip()
    if choice == "1":
        print("✅ Keeping existing model. You can start using the application!")
        return False
    if choice != "2":
        print("❌ Invalid choice. Exiting.")
        return False
    return True


def _print_training_menu():
    """Print the numbered list of training presets from ``_TRAINING_MODES``."""
    print("Training Options:")
    presets = enumerate(_TRAINING_MODES.items())
    for index, (key, (title, samples, epochs, duration, summary)) in presets:
        if index:
            print()  # blank line between presets, none after the last one
        print(f" {key}. {title}")
        print(f" • {samples} samples, {epochs} epochs")
        print(f" • Training time: {duration}")
        print(f" • {summary}")


def _select_training_mode():
    """Prompt until a valid preset is chosen; return ``(samples, epochs)``."""
    while True:
        choice = input("\nSelect training mode (1/2/3): ").strip()
        if choice in _TRAINING_MODES:
            _, samples, epochs, _, _ = _TRAINING_MODES[choice]
            return samples, epochs
        print("❌ Invalid choice. Please enter 1, 2, or 3.")


def _report_metrics(metrics):
    """Print evaluation metrics and a rough quality verdict based on loss.

    ``metrics`` is whatever ``CNNDeblurModel.evaluate_model`` returned;
    presumably a dict with 'loss', 'mae' and 'mse' keys -- TODO confirm.
    Missing keys are skipped rather than raising KeyError.
    """
    print("📊 Model Performance:")
    labels = (("loss", "Loss"),
              ("mae", "Mean Absolute Error"),
              ("mse", "Mean Squared Error"))
    for key, label in labels:
        if key in metrics:
            print(f" {label}: {metrics[key]:.4f}")
    if "loss" not in metrics:
        return
    loss = metrics["loss"]
    if loss < 0.05:
        print("🌟 Excellent! Your model is ready for high-quality deblurring!")
    elif loss < 0.1:
        print("👍 Good! Your model will provide decent deblurring results.")
    else:
        print("⚠️ Model trained but may need more training for optimal results.")


def _evaluate_and_report(model):
    """Evaluate a freshly trained model and print its metrics.

    Only called after training succeeded, so evaluation problems are shown
    as a warning instead of being reported as a training error.
    """
    print("\n🧪 Testing trained model...")
    try:
        metrics = model.evaluate_model()
        if metrics:
            _report_metrics(metrics)
    except Exception as e:
        print(f"⚠️ Model evaluation failed: {e}")


def _run_training(model_cls, model_path, samples, epochs):
    """Build, train and evaluate a model, printing progress along the way.

    Args:
        model_cls: the model class to instantiate (CNNDeblurModel); passed in
            so the project import can stay inside ``main``.
        model_path: path echoed in the success message.
        samples: number of training samples to request.
        epochs: number of training epochs.

    Returns:
        True when ``train_model`` reports success. Returns False when it
        reports failure, raises, or is interrupted with Ctrl+C.
    """
    try:
        print("\n🏗️ Initializing CNN model...")
        model = model_cls()
        print("📊 Starting training process...")
        print(" This will:")
        print(" 1. Create synthetic blur dataset")
        print(" 2. Build U-Net CNN architecture")
        print(" 3. Train the model with early stopping")
        print(" 4. Save the trained model")
        print()
        success = model.train_model(
            epochs=epochs,
            batch_size=16,
            validation_split=0.2,
            use_existing_dataset=True,
            num_training_samples=samples,
        )
        if not success:
            print("\n❌ Training Failed!")
            print(" Check the error messages above for details.")
            print(" You can still use other enhancement methods in the application.")
            return False
        print("\n🎉 Training Completed Successfully!")
        print("=" * 40)
        print(f"✅ Model saved to: {model_path}")
        print("✅ Training dataset created and saved")
        _evaluate_and_report(model)
        print("\n🚀 Next Steps:")
        print(" 1. Run the main application: streamlit run streamlit_app.py")
        print(" 2. Upload a blurry image")
        print(" 3. Select 'CNN Enhancement' method")
        print(" 4. Enjoy high-quality AI deblurring!")
        return True
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user.")
        print(" Partial progress may be saved.")
        return False
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        # Keep the full traceback (on stderr) so failures remain debuggable.
        traceback.print_exc()
        print(" You can still use traditional enhancement methods.")
        return False


def main():
    """Interactively configure and run CNN deblurring model training.

    Returns a process exit status. It is 0 when the script ends normally,
    including when the user keeps the existing model or cancels. It is 1
    when the model module cannot be imported or training does not succeed.
    """
    print("🎯 AI Image Deblurring - CNN Model Training")
    print("=" * 50)
    print()
    # Deferred until the banner has been printed; sys.path was already
    # extended with this script's folder at module import time.
    try:
        from modules.cnn_deblurring import CNNDeblurModel
    except ImportError as e:
        print(f"❌ Could not load the CNN module: {e}")
        return 1

    # NOTE(review): relative to the current working directory, not to this
    # script's folder; presumably CNNDeblurModel saves to the same relative
    # path -- confirm in modules/cnn_deblurring.py.
    model_path = "models/cnn_deblur_model.h5"

    try:
        if not _confirm_overwrite(model_path):
            return 0
        print("🚀 Starting CNN Model Training...")
        print()
        _print_training_menu()
        samples, epochs = _select_training_mode()
        print("\n🎯 Training Configuration:")
        print(f" Samples: {samples}")
        print(f" Epochs: {epochs}")
        print(f" Model will be saved to: {model_path}")
        print()
        confirm = input("Start training? (y/N): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # Ctrl+C or closed stdin at any prompt: stop cleanly, no traceback.
        print("\n❌ Training cancelled.")
        return 0
    if confirm != "y":
        print("❌ Training cancelled.")
        return 0
    return 0 if _run_training(CNNDeblurModel, model_path, samples, epochs) else 1


if __name__ == "__main__":
    # Propagate main()'s status so shells/CI can detect a failed training run.
    sys.exit(main())