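"""Train a ResNet-50 ImageNet classifier with PyTorch Lightning.

The script wraps torchvision's ResNet-50 in a LightningModule, trains it with
SGD plus a OneCycleLR schedule on multiple GPUs (DDP, 16-bit precision), logs
progress with loguru, and saves/resumes checkpoints from a local directory.
"""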
import glob
import os
import re
from datetime import datetime
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from loguru import logger
class CustomProgressBar(TQDMProgressBar):
    def __init__(self):
        super().__init__()
        self.enable = True

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        logger.info(f"\n{'='*20} Epoch {trainer.current_epoch} {'='*20}")
class ImageNetModule(LightningModule):
    """ResNet-50 classifier trained from scratch on ImageNet-1k."""

    def __init__(
        self,
        learning_rate: float = 0.1,
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        batch_size: int = 256,
        num_workers: int = 16,
        max_epochs: int = 90,
        train_path: str = "path/to/imagenet",
        val_path: str = "path/to/imagenet",
        checkpoint_dir: str = "checkpoints",
    ):
        super().__init__()
        # self.save_hyperparameters()

        # Model
        self.model = models.resnet50(weights=None)

        # Training parameters
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_epochs = max_epochs
        self.train_path = train_path
        self.val_path = val_path
        self.checkpoint_dir = checkpoint_dir

        # Metrics tracking
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.best_val_acc = 0.0

        # Set up transforms (standard ImageNet augmentation and normalization)
        self.train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.val_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)

        # Calculate top-1 accuracy (in percent)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = (correct / labels.size(0)) * 100

        # Log metrics for this step (aggregated per epoch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)

        self.training_step_outputs.append({
            'loss': loss.detach(),
            'acc': torch.tensor(accuracy)
        })
        return loss
    def on_train_epoch_end(self):
        if not self.training_step_outputs:
            logger.warning("No training outputs available for this epoch")
            return
        avg_loss = torch.stack([x['loss'] for x in self.training_step_outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in self.training_step_outputs]).mean()

        # Get current learning rate
        current_lr = self.optimizers().param_groups[0]['lr']
        logger.info(f"Training metrics - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}, LR: {current_lr:.6f}")
        self.training_step_outputs.clear()
    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)

        # Calculate top-1 accuracy (in percent)
        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = (correct / labels.size(0)) * 100

        # Log metrics for this step (aggregated per epoch)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)

        self.validation_step_outputs.append({
            'val_loss': loss.detach(),
            'val_acc': torch.tensor(accuracy)
        })
        return {'val_loss': loss, 'val_acc': accuracy}
    def on_validation_epoch_end(self):
        if not self.validation_step_outputs:
            logger.warning("No validation outputs available for this epoch")
            return
        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in self.validation_step_outputs]).mean()

        # Log final validation metrics
        self.log('val_loss_epoch', avg_loss)
        self.log('val_acc_epoch', avg_acc)

        # Save checkpoint if validation accuracy improves (skip the sanity-check run)
        if not self.trainer.sanity_checking and avg_acc > self.best_val_acc:
            self.best_val_acc = avg_acc
            checkpoint_path = os.path.join(
                self.checkpoint_dir,
                f"resnet50-epoch{self.current_epoch:02d}-acc{avg_acc:.4f}.ckpt"
            )
            self.trainer.save_checkpoint(checkpoint_path)
            logger.info(f"New best validation accuracy: {avg_acc:.4f}. Saved checkpoint to {checkpoint_path}")

        logger.info(f"Validation metrics - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")
        self.validation_step_outputs.clear()
    def train_dataloader(self):
        train_dataset = datasets.ImageFolder(
            self.train_path,
            transform=self.train_transforms
        )
        return DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = datasets.ImageFolder(
            self.val_path,
            transform=self.val_transforms
        )
        return DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay
        )
        # OneCycleLR scheduler; it must be stepped once per batch,
        # hence the "step" interval in the returned config.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.learning_rate,
            epochs=self.max_epochs,
            steps_per_epoch=len(self.train_dataloader()),
            pct_start=0.3,
            anneal_strategy='cos',
            div_factor=25.0,
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }
def setup_logging(log_dir="logs"):
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")

    logger.remove()
    # Console sink; loguru messages already end with a newline, so suppress print's own.
    logger.add(
        lambda msg: print(msg, end=""),
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | {message}",
        colorize=True,
        level="INFO"
    )
    # File sink with rotation and retention
    logger.add(
        log_file,
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        level="INFO",
        rotation="100 MB",
        retention="30 days"
    )
    logger.info(f"Logging setup complete. Logs will be saved to: {log_file}")
    return log_file
def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Find the latest checkpoint file using various possible naming patterns."""
    # Look for checkpoint files with different possible patterns
    patterns = [
        "*.ckpt",                 # Generic checkpoint files
        "resnet50-epoch*.ckpt",   # Our custom format
        "*epoch=*.ckpt",          # PyTorch Lightning default format
        "checkpoint_epoch*.ckpt"  # Another common format
    ]
    # Use a set so files matched by more than one pattern are not counted twice
    all_checkpoints = set()
    for pattern in patterns:
        checkpoint_pattern = os.path.join(checkpoint_dir, pattern)
        all_checkpoints.update(glob.glob(checkpoint_pattern))
    if not all_checkpoints:
        logger.info("No existing checkpoints found.")
        return None

    def extract_info(checkpoint_path: str) -> Tuple[int, float]:
        """Extract epoch and optional accuracy from checkpoint filename."""
        filename = os.path.basename(checkpoint_path)
        # Try different patterns to extract epoch number
        epoch_patterns = [
            r'epoch=(\d+)',     # matches epoch=X
            r'epoch(\d+)',      # matches epochX
            r'epoch[_-](\d+)',  # matches epoch_X or epoch-X
        ]
        epoch = None
        for pattern in epoch_patterns:
            match = re.search(pattern, filename)
            if match:
                epoch = int(match.group(1))
                break
        # If no epoch found, fall back to the file modification time
        if epoch is None:
            epoch = int(os.path.getmtime(checkpoint_path))
        # Try to extract accuracy if present
        acc_match = re.search(r'acc[_-]?([\d.]+)', filename)
        acc = float(acc_match.group(1)) if acc_match else 0.0
        return epoch, acc

    try:
        latest_checkpoint = max(all_checkpoints, key=lambda x: extract_info(x)[0])
        epoch, acc = extract_info(latest_checkpoint)
        logger.info(f"Found latest checkpoint: {latest_checkpoint}")
        logger.info(f"Epoch: {epoch}" + (f", Accuracy: {acc:.4f}" if acc > 0 else ""))
        return latest_checkpoint
    except Exception as e:
        logger.error(f"Error processing checkpoints: {str(e)}")
        # If there's any error in parsing, return the most recently modified file
        latest_checkpoint = max(all_checkpoints, key=os.path.getmtime)
        logger.info(f"Falling back to most recently modified checkpoint: {latest_checkpoint}")
        return latest_checkpoint
def main():
    checkpoint_dir = "/home/ec2-user/ebs/volumes/era_session9"
    log_file = setup_logging(log_dir=checkpoint_dir)

    logger.info("Starting training with configuration:")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA device count: {torch.cuda.device_count()}")
        logger.info(f"CUDA devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")

    # Resume from a specific checkpoint; switch back to find_latest_checkpoint
    # to pick up the newest one automatically.
    # latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    latest_checkpoint = "/home/ec2-user/ebs/volumes/era_session9/resnet50-epoch18-acc53.7369.ckpt"

    model = ImageNetModule(
        learning_rate=0.156,
        batch_size=256,
        num_workers=16,
        max_epochs=60,
        train_path="/home/ec2-user/ebs/volumes/imagenet/ILSVRC/Data/CLS-LOC/train",
        val_path="/home/ec2-user/ebs/volumes/imagenet/imagenet_validation",
        checkpoint_dir=checkpoint_dir
    )

    logger.info("Model configuration:")
    logger.info(f"Learning rate: {model.learning_rate}")
    logger.info(f"Batch size: {model.batch_size}")
    logger.info(f"Number of workers: {model.num_workers}")
    logger.info(f"Max epochs: {model.max_epochs}")
    progress_bar = CustomProgressBar()
    trainer = Trainer(
        max_epochs=60,
        accelerator="gpu",
        devices=4,
        strategy="ddp",
        precision=16,
        callbacks=[progress_bar],
        enable_progress_bar=True,
    )

    logger.info("Starting training")
    try:
        if latest_checkpoint:
            logger.info(f"Resuming training from checkpoint: {latest_checkpoint}")
            trainer.fit(model, ckpt_path=latest_checkpoint)
        else:
            logger.info("Starting training from scratch")
            trainer.fit(model)
        logger.info("Training completed successfully")
    except Exception as e:
        logger.error(f"Training failed with error: {str(e)}")
        raise
    finally:
        logger.info(f"Training session ended. Log file: {log_file}")
if __name__ == "__main__":
    main()
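# Typical launch, assuming this file is saved as train_imagenet.py (the actual
# file name is not given in the source):
#     python train_imagenet.py
# PyTorch Lightning's "ddp" strategy spawns one process per requested GPU on its
# own, so no torchrun/launcher wrapper is needed.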