"""
Checkpoint Manager for Mamba Swarm

Handles saving, loading, and managing model checkpoints.
"""

import hashlib
import json
import logging
import shutil
import threading
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch


@dataclass
class CheckpointMetadata:
    """Metadata describing a single saved checkpoint."""
    checkpoint_id: str
    timestamp: float
    epoch: int
    step: int
    loss: float
    model_config: Dict[str, Any]
    training_config: Dict[str, Any]
    metrics: Dict[str, float]
    file_path: str
    file_size: int
    checksum: str


class CheckpointManager:
    """Manages model checkpoints for Mamba Swarm"""

    def __init__(self,
                 checkpoint_dir: str = "./checkpoints",
                 max_checkpoints: int = 10,
                 save_interval: int = 1000,
                 best_metric: str = "loss",
                 best_metric_mode: str = "min"):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.max_checkpoints = max_checkpoints
        self.save_interval = save_interval
        self.best_metric = best_metric
        self.best_metric_mode = best_metric_mode

        self.logger = logging.getLogger(__name__)
        self.lock = threading.Lock()

        # Ensure the checkpoint directory exists
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # In-memory index of checkpoint metadata, persisted to metadata.json
        self.metadata_file = self.checkpoint_dir / "metadata.json"
        self.checkpoints_metadata: Dict[str, CheckpointMetadata] = {}

        # Best-checkpoint tracking
        self.best_checkpoint_id: Optional[str] = None
        self.best_metric_value: Optional[float] = None

        # Restore any previously saved metadata
        self._load_metadata()

    def save_checkpoint(self,
                        model_state: Dict[str, Any],
                        optimizer_state: Optional[Dict[str, Any]] = None,
                        scheduler_state: Optional[Dict[str, Any]] = None,
                        epoch: int = 0,
                        step: int = 0,
                        loss: float = 0.0,
                        metrics: Optional[Dict[str, float]] = None,
                        model_config: Optional[Dict[str, Any]] = None,
                        training_config: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[str]:
        """Save a checkpoint"""
        # Skip saves that do not fall on the save interval unless forced
        if not force_save and step % self.save_interval != 0:
            return None

        checkpoint_id = self._generate_checkpoint_id(epoch, step)

        checkpoint_data = {
            "model_state": model_state,
            "optimizer_state": optimizer_state,
            "scheduler_state": scheduler_state,
            "epoch": epoch,
            "step": step,
            "loss": loss,
            "metrics": metrics or {},
            "model_config": model_config or {},
            "training_config": training_config or {},
            "timestamp": time.time()
        }

        checkpoint_path = self.checkpoint_dir / f"{checkpoint_id}.pt"

        with self.lock:
            try:
                torch.save(checkpoint_data, checkpoint_path)

                file_size = checkpoint_path.stat().st_size
                checksum = self._calculate_checksum(checkpoint_path)

                metadata = CheckpointMetadata(
                    checkpoint_id=checkpoint_id,
                    timestamp=checkpoint_data["timestamp"],
                    epoch=epoch,
                    step=step,
                    loss=loss,
                    model_config=model_config or {},
                    training_config=training_config or {},
                    metrics=metrics or {},
                    file_path=str(checkpoint_path),
                    file_size=file_size,
                    checksum=checksum
                )

                self.checkpoints_metadata[checkpoint_id] = metadata

                # Always include the loss so best-checkpoint tracking works even
                # when the caller's metrics dict does not contain it
                self._update_best_checkpoint(checkpoint_id, {"loss": loss, **(metrics or {})})

                self._cleanup_old_checkpoints()
                self._save_metadata()

                self.logger.info(f"Saved checkpoint {checkpoint_id} at step {step}")
                return checkpoint_id

            except Exception as e:
                self.logger.error(f"Failed to save checkpoint: {e}")
                # Remove any partially written checkpoint file
                if checkpoint_path.exists():
                    checkpoint_path.unlink()
                raise

    def load_checkpoint(self, checkpoint_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Load a checkpoint"""
        if checkpoint_id is None:
            checkpoint_id = self.best_checkpoint_id

        if checkpoint_id is None or checkpoint_id not in self.checkpoints_metadata:
            self.logger.warning(f"Checkpoint {checkpoint_id} not found")
            return None

        metadata = self.checkpoints_metadata[checkpoint_id]
        checkpoint_path = Path(metadata.file_path)

        if not checkpoint_path.exists():
            self.logger.error(f"Checkpoint file {checkpoint_path} does not exist")
            return None

        try:
            if not self._verify_checksum(checkpoint_path, metadata.checksum):
                self.logger.error(f"Checkpoint {checkpoint_id} failed checksum verification")
                return None

            checkpoint_data = torch.load(checkpoint_path, map_location='cpu')

            self.logger.info(f"Loaded checkpoint {checkpoint_id} from step {metadata.step}")
            return checkpoint_data

        except Exception as e:
            self.logger.error(f"Failed to load checkpoint {checkpoint_id}: {e}")
            return None

    def load_best_checkpoint(self) -> Optional[Dict[str, Any]]:
        """Load the best checkpoint"""
        return self.load_checkpoint(self.best_checkpoint_id)

    def load_latest_checkpoint(self) -> Optional[Dict[str, Any]]:
        """Load the most recent checkpoint"""
        if not self.checkpoints_metadata:
            return None

        latest_id = max(self.checkpoints_metadata.keys(),
                        key=lambda x: self.checkpoints_metadata[x].timestamp)

        return self.load_checkpoint(latest_id)

    def list_checkpoints(self, sort_by: str = "timestamp") -> List[CheckpointMetadata]:
        """List all available checkpoints"""
        checkpoints = list(self.checkpoints_metadata.values())

        if sort_by == "timestamp":
            checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
        elif sort_by == "step":
            checkpoints.sort(key=lambda x: x.step, reverse=True)
        elif sort_by == "loss":
            checkpoints.sort(key=lambda x: x.loss)

        return checkpoints

    def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """Delete a specific checkpoint"""
        if checkpoint_id not in self.checkpoints_metadata:
            self.logger.warning(f"Checkpoint {checkpoint_id} not found")
            return False

        metadata = self.checkpoints_metadata[checkpoint_id]
        checkpoint_path = Path(metadata.file_path)

        with self.lock:
            try:
                if checkpoint_path.exists():
                    checkpoint_path.unlink()

                del self.checkpoints_metadata[checkpoint_id]

                # If the best checkpoint was deleted, pick a new one
                if checkpoint_id == self.best_checkpoint_id:
                    self._find_new_best_checkpoint()

                self._save_metadata()

                self.logger.info(f"Deleted checkpoint {checkpoint_id}")
                return True

            except Exception as e:
                self.logger.error(f"Failed to delete checkpoint {checkpoint_id}: {e}")
                return False

    def get_checkpoint_info(self, checkpoint_id: str) -> Optional[CheckpointMetadata]:
        """Get information about a specific checkpoint"""
        return self.checkpoints_metadata.get(checkpoint_id)

    def export_checkpoint(self, checkpoint_id: str, export_path: str) -> bool:
        """Export a checkpoint to a different location"""
        if checkpoint_id not in self.checkpoints_metadata:
            self.logger.error(f"Checkpoint {checkpoint_id} not found")
            return False

        metadata = self.checkpoints_metadata[checkpoint_id]
        source_path = Path(metadata.file_path)
        export_path = Path(export_path)

        try:
            shutil.copy2(source_path, export_path)

            # Write the metadata alongside the exported checkpoint
            metadata_export_path = export_path.with_suffix('.json')
            with open(metadata_export_path, 'w') as f:
                json.dump(asdict(metadata), f, indent=2)

            self.logger.info(f"Exported checkpoint {checkpoint_id} to {export_path}")
            return True

        except Exception as e:
            self.logger.error(f"Failed to export checkpoint {checkpoint_id}: {e}")
            return False

    def import_checkpoint(self, checkpoint_path: str, metadata_path: Optional[str] = None) -> Optional[str]:
        """Import a checkpoint from external location"""
        checkpoint_path = Path(checkpoint_path)

        if not checkpoint_path.exists():
            self.logger.error(f"Checkpoint file {checkpoint_path} does not exist")
            return None

        try:
            if metadata_path:
                with open(metadata_path, 'r') as f:
                    metadata_dict = json.load(f)
                metadata = CheckpointMetadata(**metadata_dict)
            else:
                # No sidecar metadata: reconstruct it from the checkpoint contents
                checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
                metadata = CheckpointMetadata(
                    checkpoint_id=self._generate_checkpoint_id(
                        checkpoint_data.get("epoch", 0),
                        checkpoint_data.get("step", 0)
                    ),
                    timestamp=checkpoint_data.get("timestamp", time.time()),
                    epoch=checkpoint_data.get("epoch", 0),
                    step=checkpoint_data.get("step", 0),
                    loss=checkpoint_data.get("loss", 0.0),
                    model_config=checkpoint_data.get("model_config", {}),
                    training_config=checkpoint_data.get("training_config", {}),
                    metrics=checkpoint_data.get("metrics", {}),
                    file_path="",
                    file_size=0,
                    checksum=""
                )

            # Copy the file into the managed checkpoint directory
            new_checkpoint_path = self.checkpoint_dir / f"{metadata.checkpoint_id}.pt"
            shutil.copy2(checkpoint_path, new_checkpoint_path)

            metadata.file_path = str(new_checkpoint_path)
            metadata.file_size = new_checkpoint_path.stat().st_size
            metadata.checksum = self._calculate_checksum(new_checkpoint_path)

            with self.lock:
                self.checkpoints_metadata[metadata.checkpoint_id] = metadata
                self._update_best_checkpoint(metadata.checkpoint_id, metadata.metrics)
                self._save_metadata()

            self.logger.info(f"Imported checkpoint {metadata.checkpoint_id}")
            return metadata.checkpoint_id

        except Exception as e:
            self.logger.error(f"Failed to import checkpoint: {e}")
            return None

    def _generate_checkpoint_id(self, epoch: int, step: int) -> str:
        """Generate unique checkpoint ID"""
        timestamp = int(time.time())
        return f"checkpoint_epoch_{epoch}_step_{step}_{timestamp}"

    def _calculate_checksum(self, file_path: Path) -> str:
        """Calculate MD5 checksum of file"""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def _verify_checksum(self, file_path: Path, expected_checksum: str) -> bool:
        """Verify file checksum"""
        actual_checksum = self._calculate_checksum(file_path)
        return actual_checksum == expected_checksum

    def _update_best_checkpoint(self, checkpoint_id: str, metrics: Dict[str, float]):
        """Update best checkpoint based on metrics"""
        if self.best_metric not in metrics:
            return

        metric_value = metrics[self.best_metric]

        if self.best_metric_value is None:
            self.best_checkpoint_id = checkpoint_id
            self.best_metric_value = metric_value
        else:
            is_better = False
            if self.best_metric_mode == "min":
                is_better = metric_value < self.best_metric_value
            elif self.best_metric_mode == "max":
                is_better = metric_value > self.best_metric_value

            if is_better:
                self.best_checkpoint_id = checkpoint_id
                self.best_metric_value = metric_value
                self.logger.info(f"New best checkpoint: {checkpoint_id} ({self.best_metric}: {metric_value})")

    def _find_new_best_checkpoint(self):
        """Find new best checkpoint after deletion"""
        if not self.checkpoints_metadata:
            self.best_checkpoint_id = None
            self.best_metric_value = None
            return

        best_id = None
        best_value = None

        for checkpoint_id, metadata in self.checkpoints_metadata.items():
            if self.best_metric in metadata.metrics:
                metric_value = metadata.metrics[self.best_metric]

                if best_value is None:
                    best_id = checkpoint_id
                    best_value = metric_value
                else:
                    is_better = False
                    if self.best_metric_mode == "min":
                        is_better = metric_value < best_value
                    elif self.best_metric_mode == "max":
                        is_better = metric_value > best_value

                    if is_better:
                        best_id = checkpoint_id
                        best_value = metric_value

        self.best_checkpoint_id = best_id
        self.best_metric_value = best_value

    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints to maintain max_checkpoints limit"""
        if len(self.checkpoints_metadata) <= self.max_checkpoints:
            return

        # Oldest checkpoints first
        sorted_checkpoints = sorted(
            self.checkpoints_metadata.items(),
            key=lambda x: x[1].timestamp
        )

        num_to_remove = len(sorted_checkpoints) - self.max_checkpoints

        for i in range(num_to_remove):
            checkpoint_id, metadata = sorted_checkpoints[i]

            # Never remove the best checkpoint
            if checkpoint_id == self.best_checkpoint_id:
                continue

            checkpoint_path = Path(metadata.file_path)
            if checkpoint_path.exists():
                checkpoint_path.unlink()

            del self.checkpoints_metadata[checkpoint_id]
            self.logger.info(f"Cleaned up old checkpoint: {checkpoint_id}")

    def _load_metadata(self):
        """Load checkpoint metadata from file"""
        if not self.metadata_file.exists():
            return

        try:
            with open(self.metadata_file, 'r') as f:
                data = json.load(f)

            for checkpoint_id, metadata_dict in data.get("checkpoints", {}).items():
                metadata = CheckpointMetadata(**metadata_dict)
                self.checkpoints_metadata[checkpoint_id] = metadata

            self.best_checkpoint_id = data.get("best_checkpoint_id")
            self.best_metric_value = data.get("best_metric_value")

            self.logger.info(f"Loaded metadata for {len(self.checkpoints_metadata)} checkpoints")

        except Exception as e:
            self.logger.error(f"Failed to load metadata: {e}")

    def _save_metadata(self):
        """Save checkpoint metadata to file"""
        try:
            data = {
                "checkpoints": {
                    checkpoint_id: asdict(metadata)
                    for checkpoint_id, metadata in self.checkpoints_metadata.items()
                },
                "best_checkpoint_id": self.best_checkpoint_id,
                "best_metric_value": self.best_metric_value,
                "last_updated": time.time()
            }

            # Write to a temp file and atomically replace so metadata.json is
            # never left half-written if the process is interrupted
            temp_file = self.metadata_file.with_suffix('.tmp')
            with open(temp_file, 'w') as f:
                json.dump(data, f, indent=2)

            temp_file.replace(self.metadata_file)

        except Exception as e:
            self.logger.error(f"Failed to save metadata: {e}")

    def get_storage_usage(self) -> Dict[str, Any]:
        """Get storage usage statistics"""
        total_size = 0
        checkpoint_count = len(self.checkpoints_metadata)

        for metadata in self.checkpoints_metadata.values():
            total_size += metadata.file_size

        return {
            "total_size_bytes": total_size,
            "total_size_mb": total_size / (1024 * 1024),
            "total_size_gb": total_size / (1024 * 1024 * 1024),
            "checkpoint_count": checkpoint_count,
            "average_size_mb": (total_size / checkpoint_count / (1024 * 1024)) if checkpoint_count > 0 else 0,
            "checkpoint_directory": str(self.checkpoint_dir)
        }
def cleanup_all_checkpoints(self):
|
|
"""Remove all checkpoints (dangerous operation)"""
|
|
with self.lock:
|
|
for metadata in self.checkpoints_metadata.values():
|
|
checkpoint_path = Path(metadata.file_path)
|
|
if checkpoint_path.exists():
|
|
checkpoint_path.unlink()
|
|
|
|
self.checkpoints_metadata.clear()
|
|
self.best_checkpoint_id = None
|
|
self.best_metric_value = None
|
|
|
|
|
|
if self.metadata_file.exists():
|
|
self.metadata_file.unlink()
|
|
|
|
self.logger.info("Cleaned up all checkpoints")
|
|


if __name__ == "__main__":
    # Simple smoke test / usage demo
    checkpoint_manager = CheckpointManager(
        checkpoint_dir="./test_checkpoints",
        max_checkpoints=5,
        save_interval=100
    )

    # Save a handful of dummy checkpoints
    for step in range(0, 1000, 100):
        model_state = {"layer_weights": torch.randn(10, 10)}
        optimizer_state = {"param_groups": [{"lr": 0.001}]}

        metrics = {
            "loss": 1.0 - step / 1000.0,
            "accuracy": step / 1000.0
        }

        checkpoint_id = checkpoint_manager.save_checkpoint(
            model_state=model_state,
            optimizer_state=optimizer_state,
            step=step,
            loss=metrics["loss"],
            metrics=metrics,
            force_save=True
        )

        print(f"Saved checkpoint: {checkpoint_id}")

    print("\nAvailable checkpoints:")
    for metadata in checkpoint_manager.list_checkpoints():
        print(f"  {metadata.checkpoint_id}: step {metadata.step}, loss {metadata.loss:.3f}")

    best_checkpoint = checkpoint_manager.load_best_checkpoint()
    print(f"\nLoaded best checkpoint: {checkpoint_manager.best_checkpoint_id}")

    usage = checkpoint_manager.get_storage_usage()
    print(f"\nStorage usage: {usage['total_size_mb']:.2f} MB ({usage['checkpoint_count']} checkpoints)")

    checkpoint_manager.cleanup_all_checkpoints()
    print("Cleaned up test checkpoints")