"""
OPTIMIZED Food101 + ResNet50 with major speed improvements:
- Mixed precision training (2x faster)
- Better data loading (persistent workers)
- Progress bars and better logging
- Robust error handling and checkpointing
"""
| |
|
| | import os |
| | import time |
| | import copy |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | from tqdm import tqdm |
| | import logging |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | import torchvision |
| | import torchvision.transforms as transforms |
| | from torch.utils.data import DataLoader |
| | from torch.cuda.amp import autocast, GradScaler |
| |
|
| | |
| | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| | |
| | |
def get_food101_loaders(batch_size=64, num_workers=8):
    """Build Food101 train/val/test DataLoaders plus the class-name list.

    Fixes over the previous revision:
      * Validation samples now use the deterministic eval transform instead
        of the training augmentations (random crop/flip/jitter), so reported
        val accuracy reflects test-time behavior.
      * The 90/10 split uses a local seeded ``torch.Generator`` rather than
        mutating the global RNG state.
      * ``persistent_workers`` is only enabled when ``num_workers > 0``
        (DataLoader raises a ValueError otherwise).

    Args:
        batch_size: samples per batch for all three loaders.
        num_workers: number of DataLoader worker processes.

    Returns:
        (train_loader, val_loader, test_loader, class_names)

    Raises:
        Re-raises any exception from dataset download/loading after logging it.
    """
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        # Two views of the same underlying train split: one augmented (for
        # training), one deterministic (for validation).
        full_train = torchvision.datasets.Food101(
            root='./data', split='train', download=True, transform=transform_train
        )
        full_train_eval = torchvision.datasets.Food101(
            root='./data', split='train', download=False, transform=transform_test
        )

        # Deterministic 90/10 split via a local generator; does not touch
        # the global RNG, so other random ops are unaffected.
        generator = torch.Generator().manual_seed(42)
        indices = torch.randperm(len(full_train), generator=generator).tolist()
        train_size = int(0.9 * len(full_train))
        train_dataset = torch.utils.data.Subset(full_train, indices[:train_size])
        val_dataset = torch.utils.data.Subset(full_train_eval, indices[train_size:])

        test_dataset = torchvision.datasets.Food101(
            root='./data', split='test', download=True, transform=transform_test
        )

        logger.info(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        # persistent_workers keeps worker processes alive between epochs,
        # but is only valid when workers exist.
        loader_kwargs = dict(
            num_workers=num_workers, pin_memory=True,
            persistent_workers=(num_workers > 0),
        )
        train_loader = DataLoader(
            train_dataset, batch_size, shuffle=True, drop_last=True, **loader_kwargs
        )
        val_loader = DataLoader(val_dataset, batch_size, shuffle=False, **loader_kwargs)
        test_loader = DataLoader(test_dataset, batch_size, shuffle=False, **loader_kwargs)

        return train_loader, val_loader, test_loader, full_train.classes

    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise
| |
|
| |
|
| | |
| | |
| | |
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block (the ResNet-18/34 building unit).

    When `downsample` is given, it projects the input so the shortcut
    matches the main path's shape/stride.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # conv -> BN pairs; conv bias is redundant before BatchNorm.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        # Shortcut branch: identity, or projected input when shapes differ.
        shortcut = self.downsample(x) if self.downsample else x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
| |
|
| |
|
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50/101/152 unit).

    The final 1x1 conv expands channels by `expansion`; `downsample`, when
    provided, projects the input so the shortcut matches the output shape.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        width = planes
        # Reduce -> spatial conv (carries the stride) -> expand.
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, width * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(width * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        shortcut = self.downsample(x) if self.downsample else x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
| |
|
| |
|
class ResNet50(nn.Module):
    """ResNet-50 built from Bottleneck blocks, with a `num_classes`-way head.

    Attribute names (conv1, bn1, layer1..layer4, fc, ...) match the standard
    ResNet layout so checkpoints remain key-compatible.
    """

    def __init__(self, num_classes=101):
        super().__init__()
        self.inplanes = 64

        # Stem: 7x7/2 conv + 3x3/2 max-pool -> 4x spatial downsample.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Standard ResNet-50 stage depths: (3, 4, 6, 3) bottlenecks.
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual units; only the first one may downsample."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # 1x1 projection so the shortcut matches the new shape/stride.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        units = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels  # subsequent units take the expanded width
        units.extend(block(self.inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*units)

    def _initialize_weights(self):
        """He-init conv weights; BatchNorm starts as identity (scale 1, shift 0)."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
| |
|
| |
|
| | |
| | |
| | |
def train_model(model, train_loader, val_loader, test_loader, device, num_epochs=100, resume_from=None):
    """Optimized training loop with mixed precision and resumable checkpoints.

    Fixes over the previous revision:
      * Resuming now restores the AMP GradScaler and the LR-scheduler state
        (previously the scaler restarted cold and the cosine warm-restart
        schedule reset to epoch 0 on every resume).
      * The scheduler state is saved into every checkpoint so future resumes
        have it available.
      * Resuming at or past ``num_epochs`` no longer crashes on an empty
        metric history when building the final checkpoint.

    Args:
        model: network already moved to ``device``.
        train_loader, val_loader, test_loader: DataLoaders.
        device: torch.device for all tensors.
        num_epochs: total epoch count, including epochs completed before a resume.
        resume_from: optional path to a checkpoint produced by this function.

    Returns:
        (best_val_acc, train_losses, val_accuracies)
    """
    os.makedirs('./outputs', exist_ok=True)

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # AMP loss scaler; effectively a no-op when CUDA is unavailable.
    scaler = GradScaler()

    best_val_acc = 0.0
    train_losses, val_accuracies, learning_rates = [], [], []
    start_epoch = 0

    if resume_from and os.path.exists(resume_from):
        logger.info(f"Resuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Older checkpoints may lack these keys, hence the guards.
        if 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_accuracy', 0.0)
        train_losses = checkpoint.get('train_losses', [])
        val_accuracies = checkpoint.get('val_accuracies', [])
        learning_rates = checkpoint.get('learning_rates', [])

    logger.info(f"π Starting training from epoch {start_epoch+1} for {num_epochs} total epochs...")

    total_train_time = 0

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()

        # ---- training phase ----
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)

        for batch_idx, (images, labels) in enumerate(train_pbar):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            # Forward pass in reduced precision where safe.
            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            # Scaled backward + step; scaler skips the step on inf/nan grads.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'})

        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # ---- validation phase ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]', leave=False)

        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_pbar.set_postfix({'acc': f'{100.*correct/total:.2f}%'})

        val_acc = 100. * correct / total
        val_accuracies.append(val_acc)
        avg_val_loss = val_loss / len(val_loader)

        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Checkpoint every 10 epochs, on every new best, and at the very end.
        if (epoch + 1) % 10 == 0 or is_best or epoch == num_epochs - 1:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_accuracy': best_val_acc,
                'current_val_accuracy': val_acc,
                'train_losses': train_losses,
                'val_accuracies': val_accuracies,
                'learning_rates': learning_rates,
            }

            if is_best:
                torch.save(checkpoint, './outputs/food101_resnet50_best.pth')
                # Weights-only copy for lightweight inference loading.
                torch.save(model.state_dict(), './outputs/food101_resnet50_best_weights.pth')

            if (epoch + 1) % 10 == 0:
                torch.save(checkpoint, f'./outputs/food101_resnet50_epoch_{epoch+1}.pth')

        scheduler.step()
        epoch_time = time.time() - epoch_start
        total_train_time += epoch_time

        logger.info(f"Epoch {epoch+1:3d}/{num_epochs} | "
                    f"Train Loss: {avg_train_loss:.4f} | "
                    f"Val Loss: {avg_val_loss:.4f} | "
                    f"Val Acc: {val_acc:.2f}% | "
                    f"Best: {best_val_acc:.2f}% | "
                    f"LR: {optimizer.param_groups[0]['lr']:.6f} | "
                    f"Time: {epoch_time:.1f}s")

    # Final checkpoint; tolerant of a run resumed past num_epochs (empty history).
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'final_val_accuracy': val_accuracies[-1] if val_accuracies else best_val_acc,
        'best_val_accuracy': best_val_acc,
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'learning_rates': learning_rates,
        'total_train_time': total_train_time,
    }
    torch.save(final_checkpoint, './outputs/food101_resnet50_final.pth')
    torch.save(model.state_dict(), './outputs/food101_resnet50_final_weights.pth')

    logger.info(f"π Total training time: {total_train_time/3600:.2f} hours")

    test_acc = evaluate_model(model, test_loader, device, "Test")
    logger.info(f"π― Final Test Accuracy: {test_acc:.2f}%")

    plot_training_curves(train_losses, val_accuracies, learning_rates)

    return best_val_acc, train_losses, val_accuracies
| |
|
| |
|
def evaluate_model(model, test_loader, device, dataset_name="Test"):
    """Return the model's top-1 accuracy (%) over `test_loader`.

    Runs in eval mode under no_grad with autocast; a tqdm bar shows the
    running accuracy while iterating.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    progress = tqdm(test_loader, desc=f'{dataset_name} Evaluation', leave=False)

    with torch.no_grad():
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)

            with autocast():
                logits = model(batch_images)

            _, preds = torch.max(logits, 1)
            n_seen += batch_labels.size(0)
            n_correct += (preds == batch_labels).sum().item()

            progress.set_postfix({'acc': f'{100.*n_correct/n_seen:.2f}%'})

    return 100. * n_correct / n_seen
| |
|
| |
|
def plot_training_curves(train_losses, val_accuracies, learning_rates):
    """Render and save training diagnostics to ./outputs.

    Produces two PNGs:
      * training_analysis.png — 2x2 grid: train loss (log y), val accuracy,
        LR schedule (log y), and a twin-axis loss-vs-accuracy overlay.
      * accuracy_detail.png — standalone validation-accuracy curve.

    Fix: the output directory is now created here, so the function works when
    called standalone (it previously relied on train_model having created
    ./outputs first).

    Args:
        train_losses: per-epoch mean training loss.
        val_accuracies: per-epoch validation accuracy (%). Must be non-empty
            (max() is taken for the best-accuracy reference line).
        learning_rates: per-epoch learning rate.
    """
    os.makedirs('./outputs', exist_ok=True)
    epochs = np.arange(1, len(train_losses) + 1)

    plt.style.use('default')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Food101 ResNet50 Training Analysis', fontsize=16, fontweight='bold')

    # Training loss (log scale so late-epoch changes stay visible).
    axes[0, 0].plot(epochs, train_losses, 'b-', linewidth=2, alpha=0.8)
    axes[0, 0].set_title('Training Loss Over Time', fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')

    # Validation accuracy with a dashed line marking the best epoch.
    axes[0, 1].plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    axes[0, 1].set_title('Validation Accuracy Over Time', fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].axhline(y=max(val_accuracies), color='r', linestyle='--', alpha=0.7,
                       label=f'Best: {max(val_accuracies):.2f}%')
    axes[0, 1].legend()

    # Learning-rate schedule (log scale for the cosine warm restarts).
    axes[1, 0].plot(epochs, learning_rates, 'g-', linewidth=2, alpha=0.8)
    axes[1, 0].set_title('Learning Rate Schedule', fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Learning Rate')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_yscale('log')

    # Loss and accuracy on twin y-axes for correlation at a glance.
    ax_combined = axes[1, 1]
    ax_combined.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2, alpha=0.8)
    ax_combined.set_xlabel('Epoch')
    ax_combined.set_ylabel('Loss', color='b')
    ax_combined.tick_params(axis='y', labelcolor='b')
    ax_combined.set_yscale('log')

    ax2 = ax_combined.twinx()
    ax2.plot(epochs, val_accuracies, 'r-', label='Val Accuracy', linewidth=2, alpha=0.8)
    ax2.set_ylabel('Accuracy (%)', color='r')
    ax2.tick_params(axis='y', labelcolor='r')

    ax_combined.set_title('Loss vs Accuracy', fontweight='bold')
    ax_combined.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('./outputs/training_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Standalone, larger accuracy plot with a shaded area under the curve.
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, val_accuracies, 'r-', linewidth=2, alpha=0.8)
    plt.fill_between(epochs, val_accuracies, alpha=0.3)
    plt.title('Validation Accuracy Progress', fontsize=14, fontweight='bold')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)
    plt.axhline(y=max(val_accuracies), color='r', linestyle='--', alpha=0.7,
                label=f'Peak Accuracy: {max(val_accuracies):.2f}%')
    plt.legend()
    plt.tight_layout()
    plt.savefig('./outputs/accuracy_detail.png', dpi=300, bbox_inches='tight')
    plt.close()

    logger.info("π Saved enhanced training visualizations")
| |
|
| |
|
def save_classes(classes):
    """Write the Food101 class names to ./outputs in two formats.

    food101_classes.txt gets a numbered, title-cased listing with a header;
    food101_classes_simple.txt gets one raw class name per line. Both files
    list the classes in sorted order.
    """
    os.makedirs('./outputs', exist_ok=True)
    ordered = sorted(classes)

    # Human-friendly listing: header plus numbered, prettified names.
    pretty_lines = ["Food101 Classes (101 total)\n", "=" * 30 + "\n\n"]
    pretty_lines += [
        f"{idx:3d}. {name.replace('_', ' ').title()}\n"
        for idx, name in enumerate(ordered, 1)
    ]
    with open('./outputs/food101_classes.txt', 'w') as fh:
        fh.writelines(pretty_lines)

    # Machine-friendly listing: raw names, one per line.
    with open('./outputs/food101_classes_simple.txt', 'w') as fh:
        fh.writelines(f"{name}\n" for name in ordered)

    logger.info("π Saved class names to ./outputs/")
| |
|
| |
|
def print_system_info():
    """Log PyTorch/CUDA/hardware details to aid environment debugging."""
    logger.info("π₯οΈ System Information:")
    logger.info(f"PyTorch version: {torch.__version__}")

    cuda_ok = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"GPU memory: {gpu_mem_gb:.1f} GB")

    logger.info(f"Number of CPU cores: {os.cpu_count()}")
| |
|
| |
|
| | |
| | |
| | |
def main():
    """Entry point: load data, build ResNet50, train, and report artifacts."""
    print_system_info()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    try:
        logger.info("π₯ Loading Food101 dataset...")
        loaders = get_food101_loaders(batch_size=64, num_workers=8)
        train_loader, val_loader, test_loader, classes = loaders
        save_classes(classes)

        logger.info("ποΈ Building ResNet50...")
        model = ResNet50(num_classes=101).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params/1e6:.1f}M")
        logger.info(f"Trainable parameters: {trainable_params/1e6:.1f}M")

        # NOTE(review): compiling before checkpointing means saved state_dicts
        # may carry an `_orig_mod.` key prefix — confirm before loading these
        # weights into an uncompiled model elsewhere.
        if hasattr(torch, 'compile'):
            logger.info("π Compiling model for faster training...")
            model = torch.compile(model)

        # Resume automatically from the best checkpoint when one exists.
        best_ckpt = './outputs/food101_resnet50_best.pth'
        resume_path = best_ckpt if os.path.exists(best_ckpt) else None

        best_val_acc, losses, accuracies = train_model(
            model, train_loader, val_loader, test_loader, device,
            num_epochs=100, resume_from=resume_path
        )

        logger.info(f"\nπ TRAINING COMPLETE!")
        logger.info(f"π Best Validation Accuracy: {best_val_acc:.2f}%")
        logger.info(f"\nπ SAVED FILES:")
        artifacts = (
            "./outputs/food101_resnet50_best.pth (best checkpoint)",
            "./outputs/food101_resnet50_best_weights.pth (best weights only)",
            "./outputs/food101_resnet50_final.pth (final checkpoint)",
            "./outputs/food101_resnet50_final_weights.pth (final weights only)",
            "./outputs/training_analysis.png (comprehensive plots)",
            "./outputs/accuracy_detail.png (detailed accuracy)",
            "./outputs/food101_classes.txt (formatted class list)",
            "./outputs/food101_classes_simple.txt (simple class list)",
        )
        for artifact in artifacts:
            logger.info(f"  β’ {artifact}")

    except Exception as e:
        logger.error(f"β Training failed with error: {e}")
        raise
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |