from enum import Enum
import logging
from pathlib import Path
import re
import typing as tp

import flashy
import torch

from ..environment import AudioCraftEnvironment


logger = logging.getLogger(__name__)


class CheckpointSource(Enum):
    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"


def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Checkpoint name formatted for all uses in the AudioCraft codebase. It has the following format:
    `checkpoint_<name>.th(.<rank>)`. By convention, the name is expected to be empty for the last checkpoint,
    'best' for the best checkpoint, or the epoch number.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (int, optional): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    suffix = ''
    if rank is None:
        rank = flashy.distrib.rank()
    if rank > 0 and use_fsdp:
        suffix = '.' + str(rank)
    name_part = ''
    if name is not None:
        name_part = f'_{name}'
    return f'checkpoint{name_part}.th{suffix}'
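
# Illustrative outputs of `checkpoint_name` (a sketch derived from the logic above,
# assuming `flashy.distrib.rank()` returns 0 unless a rank is given explicitly):
#   checkpoint_name()                             -> 'checkpoint.th'
#   checkpoint_name('best')                       -> 'checkpoint_best.th'
#   checkpoint_name('10')                         -> 'checkpoint_10.th'
#   checkpoint_name('10', rank=2, use_fsdp=True)  -> 'checkpoint_10.th.2'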


def is_sharded_checkpoint(path: Path) -> bool:
    """Whether the checkpoint at the given path corresponds to a sharded checkpoint across ranks."""
    return re.search(r'\.th\.\d+$', path.name) is not None
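
# For example, following the regex above, 'checkpoint_10.th.3' is the rank-3 shard of a
# sharded checkpoint, while 'checkpoint_10.th' and 'checkpoint_best.th' are not sharded.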


def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a checkpoint path for a provided dora signature or path.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    from audiocraft import train
    xps_root = train.main.dora.dir / 'xps'
    sig_or_path = str(sig_or_path)
    if sig_or_path.startswith('//sig/'):
        # A '//sig/<sig>' reference points to an experiment folder under the dora xps root.
        sig = sig_or_path[len('//sig/'):]
        path = xps_root / sig
    else:
        path = Path(sig_or_path)
        path = AudioCraftEnvironment.resolve_reference_path(path)

    if path.is_dir():
        # When given a folder, look for the conventional checkpoint file name inside it.
        path = path / checkpoint_name(name, use_fsdp=use_fsdp)

    if path.exists():
        return path
    else:
        return None
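
# Example usage (a sketch; '//sig/abc123' stands in for a hypothetical dora signature,
# and a path is only returned if the corresponding checkpoint file actually exists):
#   ckpt = resolve_checkpoint_path('//sig/abc123', name='best')
#   if ckpt is not None:
#       state = load_checkpoint(ckpt)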


def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load state from the checkpoint at the specified path."""
    if is_sharded:
        # For sharded checkpoints, first make sure any interrupted save has been cleaned up.
        rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if rank0_checkpoint_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
    state = torch.load(checkpoint_path, 'cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return state


def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save state to disk at the specified checkpoint path."""
    _safe_save_checkpoint(state, checkpoint_path, is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)


def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Flush stale checkpoints so that only the last `keep_last` epoch checkpoints are kept."""
    if keep_last is None or keep_last <= 0:
        return
    checkpoint_dir = checkpoint_path.parent
    suffix = ''
    if flashy.distrib.rank() > 0:
        suffix = f'.{flashy.distrib.rank()}'
    checkpoint_files_with_epoch = []
    for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
        epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
        if epoch_part.isdigit():
            checkpoint_files_with_epoch.append((path, int(epoch_part)))
    checkpoint_files = [path for path, _ in sorted(checkpoint_files_with_epoch, key=lambda t: t[1])]
    total_to_flush = max(0, len(checkpoint_files) - keep_last)
    files_to_flush = checkpoint_files[:total_to_flush]
    for path in files_to_flush:
        logger.debug("Removing checkpoint: %s", str(path))
        path.unlink(missing_ok=True)
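
# For instance, with keep_last=2 and 'checkpoint_8.th', 'checkpoint_9.th' and
# 'checkpoint_10.th' on disk, only 'checkpoint_8.th' is removed. 'checkpoint.th'
# does not match the glob, and 'checkpoint_best.th' is skipped because 'best' is
# not an epoch number.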


def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted."""
    # Finish the work of a previous run that got interrupted while dumping.
    old_path = Path(str(checkpoint_path) + '.old')
    if old_path.exists():
        raise RuntimeError(
            f"Old checkpoint {old_path} from previous version of this code exists, cannot safely proceed.")
    token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    tmp_path = Path(str(checkpoint_path) + '.tmp')
    if token.exists():
        if tmp_path.exists():
            tmp_path.rename(checkpoint_path)
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and token.exists():
        token.unlink()


def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save checkpoints in a safe manner, even when using sharded checkpoints across nodes."""
    def _barrier_if_sharded():
        if is_sharded:
            flashy.distrib.barrier()

    if flashy.distrib.is_rank_zero():
        # Remove any stale completion token left behind by a previous save.
        token = Path(str(checkpoint_path) + '.tmp.done')
        if token.exists():
            token.unlink()
    _barrier_if_sharded()
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        # Write to a temporary file, then atomically rename it into place on exit.
        torch.save(state, f)
    _barrier_if_sharded()
    if flashy.distrib.is_rank_zero():
        # Mark the save as complete on rank zero.
        token.touch()
    _barrier_if_sharded()
    if flashy.distrib.is_rank_zero():
        # Every rank has passed the final barrier; the completion token is no longer needed.
        token.unlink()