| """ | |
| Continuous learning script for model improvement | |
| Fine-tunes existing model with new corrected samples | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from torchvision import models | |
| from pathlib import Path | |
| import shutil | |
| from datetime import datetime | |
| import json | |
| from .train import WasteDataset, get_transforms, validate, CATEGORIES, CONFIG | |
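
# Data flow, as implemented below:
#   ml/data/retraining/<category>/           - corrected samples awaiting retraining
#   <CONFIG['data_dir']>/train/<category>/   - new samples are copied here for fine-tuning
#   ml/data/retraining_archive/<timestamp>/  - consumed samples are archived after each run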


def get_model_version():
    """Get the next model version number."""
    model_dir = Path(CONFIG['model_dir'])
    existing_versions = list(model_dir.glob('model_v*.pth'))
    if not existing_versions:
        return 1
    versions = [int(p.stem.split('_v')[1]) for p in existing_versions]
    return max(versions) + 1
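
# Versioning example: with model_v1.pth and model_v3.pth on disk, the next save is
# model_v4.pth. Backup files named best_model_backup_*.pth do not match the glob in
# get_model_version, so they never affect version numbering.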


def prepare_retraining_data():
    """Organize retraining data into the processed-data structure."""
    retraining_dir = Path('ml/data/retraining')
    processed_dir = Path(CONFIG['data_dir'])

    if not retraining_dir.exists():
        print("No retraining data found")
        return 0

    # Count new samples per category
    new_samples = 0
    for category in CATEGORIES:
        category_dir = retraining_dir / category
        if category_dir.exists():
            images = list(category_dir.glob('*.jpg')) + list(category_dir.glob('*.png'))
            new_samples += len(images)

            # Copy to training set
            target_dir = processed_dir / 'train' / category
            target_dir.mkdir(parents=True, exist_ok=True)
            for img_path in images:
                target_path = target_dir / f"retrain_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{img_path.name}"
                shutil.copy(img_path, target_path)

    print(f"Added {new_samples} new samples to training set")
    return new_samples


def retrain_model(base_model_path='ml/models/best_model.pth',
                  num_epochs=10,
                  learning_rate=0.0001):
    """
    Fine-tune the existing model with new data.
    Uses a lower learning rate for incremental learning.
    """
    print("Starting retraining process...")

    # Prepare new data
    new_samples = prepare_retraining_data()
    if new_samples == 0:
        print("No new samples to train on")
        return None

    # Setup device
    device = torch.device(CONFIG['device'])
    print(f"Using device: {device}")

    # Load the base checkpoint; it stores plain Python metadata (config, categories),
    # so weights_only=False is required on PyTorch >= 2.6, where weights-only loading
    # became the default.
    checkpoint = torch.load(base_model_path, map_location=device, weights_only=False)

    # Rebuild the architecture; weights come from the checkpoint, not torchvision
    model = models.efficientnet_b0(weights=None)
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, CONFIG['num_classes'])
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    print(f"Loaded base model with accuracy: {checkpoint['accuracy']:.2f}%")

    # Create datasets over the updated data
    train_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='train',
        transform=get_transforms('train')
    )
    val_dataset = WasteDataset(
        CONFIG['data_dir'],
        split='val',
        transform=get_transforms('val')
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=4
    )

    # Setup training
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_acc = checkpoint['accuracy']
    improvement_threshold = 1.0  # must improve by at least 1 percentage point
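
    # Note on the setup above: Adam updates all parameters at a deliberately low
    # learning rate (1e-4 by default) for incremental fine-tuning; any validation
    # gain yields a new versioned checkpoint, while promotion to best_model.pth
    # additionally requires a gain of at least improvement_threshold points.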

    # Fine-tuning loop
    for epoch in range(num_epochs):
        print(f"\nRetraining Epoch {epoch + 1}/{num_epochs}")
        print("-" * 50)

        # Train
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validate
        val_loss, val_acc, f1_macro, f1_weighted, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )
        print(f"Val Acc: {val_acc:.2f}% | F1 Macro: {f1_macro:.4f}")

        # Check improvement
        if val_acc > best_acc:
            improvement = val_acc - best_acc
            best_acc = val_acc

            # Save improved model
            version = get_model_version()
            new_model_path = f"{CONFIG['model_dir']}/model_v{version}.pth"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': val_acc,
                'f1_macro': f1_macro,
                'f1_weighted': f1_weighted,
                'categories': CATEGORIES,
                'config': CONFIG,
                'base_model': base_model_path,
                'new_samples': new_samples,
                'improvement': improvement,
                'retrain_date': datetime.now().isoformat()
            }, new_model_path)
            print(f"✓ Improved model saved as v{version} (+{improvement:.2f}%)")

            # If the improvement is significant, promote to production
            if improvement >= improvement_threshold:
                production_path = f"{CONFIG['model_dir']}/best_model.pth"

                # Backup old production model
                if Path(production_path).exists():
                    backup_path = f"{CONFIG['model_dir']}/best_model_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
                    shutil.copy(production_path, backup_path)

                # Promote new model
                shutil.copy(new_model_path, production_path)
                print("✓ Model promoted to production!")

            # Log retraining event
            log_retraining_event(version, val_acc, improvement, new_samples)

    # Clean up retraining directory by archiving the consumed samples
    retraining_dir = Path('ml/data/retraining')
    archive_dir = Path('ml/data/retraining_archive') / datetime.now().strftime('%Y%m%d_%H%M%S')
    archive_dir.mkdir(parents=True, exist_ok=True)
    for category in CATEGORIES:
        category_dir = retraining_dir / category
        if category_dir.exists():
            shutil.move(str(category_dir), str(archive_dir / category))

    print(f"\nRetraining complete! Final accuracy: {best_acc:.2f}%")
    return model


def log_retraining_event(version, accuracy, improvement, new_samples):
    """Log retraining events for monitoring."""
    log_file = Path(CONFIG['model_dir']) / 'retraining_log.json'

    event = {
        'version': version,
        'timestamp': datetime.now().isoformat(),
        'accuracy': accuracy,
        'improvement': improvement,
        'new_samples': new_samples
    }

    # Load existing log
    if log_file.exists():
        with open(log_file, 'r') as f:
            log = json.load(f)
    else:
        log = []

    log.append(event)

    # Save updated log
    with open(log_file, 'w') as f:
        json.dump(log, f, indent=2)

    print("Retraining event logged")


if __name__ == "__main__":
    retrain_model()
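
# Usage note: the relative import (from .train import ...) means this script must be
# run as a module from the project root, assuming it lives in an `ml` package next to
# train.py, e.g.:
#   python -m ml.retrain
# Running the file directly (python retrain.py) would raise an ImportError on the
# relative import.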