import wandb
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from utils.data_loader import get_cifar10_info


class WandbLogger:
    """Minimal yet powerful W&B integration for FAANG-level ML projects."""

    def __init__(self, project: str = "cifar10-benchmark", entity: Optional[str] = None):
        self.project = project
        self.entity = entity
        self.run = None

    def init_experiment(self, config: Dict[str, Any], model: nn.Module, model_name: str):
        """Initialize W&B run with model architecture tracking."""
        # Auto-detect model stats for config
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        enhanced_config = {
            **config,
            'model_name': model_name,
            'total_params': total_params,
            'trainable_params': trainable_params,
            'model_size_mb': total_params * 4 / (1024 ** 2),  # assumes float32 weights
            'architecture': model.__class__.__name__
        }

        self.run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=enhanced_config,
            name=f"{model_name}-{wandb.util.generate_id()}"
        )

        # Track gradients, parameters, and the model graph
        wandb.watch(model, log_freq=100, log_graph=True)
        return self.run

    def log_metrics(self, metrics: Dict[str, float], step: int):
        """Log a dictionary of training metrics at the given step."""
        wandb.log(metrics, step=step)

    def log_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray, epoch: int):
        """Log confusion matrix as a W&B image."""
        cifar10_info = get_cifar10_info()
        cm = confusion_matrix(y_true, y_pred)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=cifar10_info['class_names'],
                    yticklabels=cifar10_info['class_names'])
        plt.title(f'Confusion Matrix - Epoch {epoch}')
        plt.tight_layout()

        # Log at the epoch step so the image lines up with the per-epoch metrics
        wandb.log({
            "confusion_matrix": wandb.Image(plt.gcf()),
            "epoch": epoch
        }, step=epoch)
        plt.close()

    def log_model_checkpoint(self, model: nn.Module, optimizer, epoch: int,
                             metrics: Dict[str, float], is_best: bool = False):
        """Log model checkpoint with metadata."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            **metrics
        }

        filename = f"model_epoch_{epoch}.pth"
        torch.save(checkpoint, filename)

        artifact = wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model",
            metadata={"epoch": epoch, "is_best": is_best, **metrics}
        )
        artifact.add_file(filename)
        wandb.log_artifact(artifact)

    def finish(self):
        """Cleanup W&B run."""
        if self.run:
            wandb.finish()
            self.run = None
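

# --- Usage sketch (illustrative) ----------------------------------------------
# A minimal sketch of how WandbLogger is meant to be driven, assuming a logged-in
# W&B environment. The toy model, config, and metric values below are
# placeholders for illustration only, not part of the benchmark code above.
def _example_logger_usage():
    """Hypothetical end-to-end use of WandbLogger on a toy model."""
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    config = {'learning_rate': 1e-3, 'batch_size': 128, 'num_epochs': 1}

    logger = WandbLogger(project="cifar10-benchmark")
    logger.init_experiment(config, model, model_name="toy-linear")

    # Per-epoch metrics are logged with an explicit step so W&B charts line up.
    logger.log_metrics({'train_loss': 1.9, 'val_accuracy': 34.2}, step=0)
    logger.finish()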


def create_hyperparameter_sweep():
    """FAANG-level hyperparameter sweep configuration."""
    return {
        'method': 'bayes',
        'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
        'parameters': {
            # log_uniform_values samples on a log scale directly between the given
            # values (plain log_uniform would expect log-space bounds instead)
            'learning_rate': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-2},
            'batch_size': {'values': [32, 64, 128]},
            'weight_decay': {'distribution': 'log_uniform_values', 'min': 1e-6, 'max': 1e-3},
            'optimizer': {'values': ['adamw', 'sgd']},
            'scheduler': {'values': ['cosine', 'step']},
            'dropout_rate': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
        }
    }


def run_hyperparameter_sweep(train_fn, sweep_config: Dict[str, Any], count: int = 20):
    """Execute hyperparameter sweep with W&B."""
    sweep_id = wandb.sweep(sweep_config, project="cifar10-benchmark")
    wandb.agent(sweep_id, train_fn, count=count)
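

# --- Sweep entry-point sketch (illustrative) -----------------------------------
# wandb.agent calls train_fn once per trial; a typical entry point opens a run,
# reads the sampled hyperparameters from wandb.config, and trains with them.
# Everything below the config lookups is a placeholder for the real training code.
def _example_sweep_train_fn():
    """Hypothetical train_fn suitable for run_hyperparameter_sweep."""
    with wandb.init() as run:
        cfg = run.config
        # Values sampled from create_hyperparameter_sweep()
        lr = cfg.learning_rate
        batch_size = cfg.batch_size
        weight_decay = cfg.weight_decay

        # ... build model, loaders, and optimizer from these values, then train ...
        # The sweep optimizes whatever is logged under the metric name 'val_accuracy'.
        wandb.log({'val_accuracy': 0.0})

# Example launch:
# run_hyperparameter_sweep(_example_sweep_train_fn, create_hyperparameter_sweep(), count=20)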


# Integration with existing training loop
def enhanced_train_step(model, train_loader, val_loader, optimizer, criterion,
                        scheduler, num_epochs, device, logger: WandbLogger):
    """Enhanced training with W&B logging."""
    model.to(device)
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss, train_acc = 0.0, 0.0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_acc += (outputs.argmax(1) == targets).float().mean().item()

        # Validation
        model.eval()
        val_loss, val_acc = 0.0, 0.0
        all_preds, all_targets = [], []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                val_acc += (outputs.argmax(1) == targets).float().mean().item()
                all_preds.extend(outputs.argmax(1).cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        # Average the per-batch metrics
        train_loss /= len(train_loader)
        train_acc /= len(train_loader)
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)

        scheduler.step()

        # Log to W&B
        logger.log_metrics({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_accuracy': train_acc * 100,
            'val_loss': val_loss,
            'val_accuracy': val_acc * 100,
            'learning_rate': optimizer.param_groups[0]['lr']
        }, step=epoch)

        # Log confusion matrix every 10 epochs
        if (epoch + 1) % 10 == 0:
            logger.log_confusion_matrix(np.array(all_targets), np.array(all_preds), epoch)

        # Save best model
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            logger.log_model_checkpoint(
                model, optimizer, epoch,
                {'val_accuracy': val_acc, 'val_loss': val_loss},
                is_best=True
            )

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train: {train_loss:.4f}/{train_acc:.3f} | "
              f"Val: {val_loss:.4f}/{val_acc:.3f}")

    return best_val_acc
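

# --- Putting it together (illustrative) -----------------------------------------
# A sketch of how WandbLogger and enhanced_train_step could be wired up for a
# single run. The hyperparameters here are placeholders, and the model and data
# loaders are assumed to be built elsewhere (e.g. by the project's own model zoo
# and utils.data_loader helpers).
def _example_training_run(model: nn.Module, train_loader, val_loader, device: str = "cuda"):
    """Hypothetical single-run driver around enhanced_train_step."""
    config = {
        'learning_rate': 1e-3,
        'weight_decay': 1e-4,
        'num_epochs': 50,
    }

    logger = WandbLogger(project="cifar10-benchmark")
    logger.init_experiment(config, model, model_name=model.__class__.__name__)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config['learning_rate'],
                                  weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'])

    best_val_acc = enhanced_train_step(model, train_loader, val_loader, optimizer,
                                       criterion, scheduler, config['num_epochs'],
                                       device, logger)
    logger.finish()
    return best_val_acc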