MusicGen / audiocraft /utils /checkpoint.py
reach-vb's picture
reach-vb HF staff
Stereo demo update (#60)
5325fcc
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
import logging
from pathlib import Path
import re
import typing as tp
import flashy
import torch
from ..environment import AudioCraftEnvironment
logger = logging.getLogger(__name__)
class CheckpointSource(Enum):
CURRENT_XP = "current_xp"
PRETRAINED = "pretrained"
OTHER = "other"
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
"""Checkpoint name formatted for all use in AudioCraft codebase and has the following format:
`checkpoint_<name>.th(.<rank>)`. By convention, name is expected to be empty for last checkpoint,
'best' for the best checkpoint or the epoch number.
Args:
name (str, optional): Name suffix for the checkpoint file stem.
rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
use_fsdp (bool): Whether the calling solver relies on FSDP.
Returns:
str: The checkpoint name.
"""
suffix = ''
if rank is None:
rank = flashy.distrib.rank()
if rank > 0 and use_fsdp:
suffix = '.' + str(rank)
name_part = ''
if name is not None:
name_part = f'_{name}'
return f'checkpoint{name_part}.th{suffix}'
def is_sharded_checkpoint(path: Path) -> bool:
"""Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank."""
return re.search(r'\.th\.\d+$', path.name) is not None
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
use_fsdp: bool = False) -> tp.Optional[Path]:
"""Resolve a given checkpoint path for a provided dora sig or path.
Args:
sig_or_path (Path or str): Checkpoint path or dora signature.
name (str, optional): Name suffix for the checkpoint file stem.
rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
use_fsdp (bool): Whether the calling solver relies on FSDP.
Returns:
Path, optional: Resolved checkpoint path, if it exists.
"""
from audiocraft import train
xps_root = train.main.dora.dir / 'xps'
sig_or_path = str(sig_or_path)
if sig_or_path.startswith('//sig/'):
sig = sig_or_path[len('//sig/'):]
path = xps_root / sig
else:
path = Path(sig_or_path)
path = AudioCraftEnvironment.resolve_reference_path(path)
if path.is_dir():
path = path / checkpoint_name(name, use_fsdp=use_fsdp)
if path.exists():
return path
else:
return None
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
"""Load state from checkpoints at the specified checkpoint path."""
if is_sharded:
rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
if rank0_checkpoint_path.exists():
check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
state = torch.load(checkpoint_path, 'cpu')
logger.info("Checkpoint loaded from %s", checkpoint_path)
return state
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
"""Save state to disk to the specified checkpoint_path."""
_safe_save_checkpoint(state, checkpoint_path, is_sharded)
logger.info("Checkpoint saved to %s", checkpoint_path)
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
"""Flush checkpoints to only keep last N checkpoints."""
if keep_last is None or keep_last <= 0:
return
checkpoint_dir = checkpoint_path.parent
suffix = ''
if flashy.distrib.rank() > 0:
suffix = f'.{flashy.distrib.rank()}'
checkpoint_files_with_epoch = []
for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
if epoch_part.isdigit():
checkpoint_files_with_epoch.append((path, int(epoch_part)))
checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))]
total_to_flush = max(0, len(checkpoint_files) - keep_last)
files_to_flush = checkpoint_files[:total_to_flush]
for path in files_to_flush:
logger.debug("Removing checkpoint: %s", str(path))
path.unlink(missing_ok=True)
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
"""Check sharded checkpoint state, ensuring the checkpoints are not corrupted."""
# Finish the work of a previous run that got interrupted while dumping.
old_path = Path(str(checkpoint_path) + '.old')
if old_path.exists():
raise RuntimeError(
f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
token = Path(str(rank0_checkpoint_path) + '.tmp.done')
tmp_path = Path(str(checkpoint_path) + '.tmp')
if token.exists():
if tmp_path.exists():
tmp_path.rename(checkpoint_path)
flashy.distrib.barrier()
if flashy.distrib.is_rank_zero() and token.exists():
token.unlink()
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
"""Save checkpoints in a safe manner even with when sharded checkpoints across nodes."""
def _barrier_if_sharded():
if is_sharded:
flashy.distrib.barrier()
if flashy.distrib.is_rank_zero():
token = Path(str(checkpoint_path) + '.tmp.done')
if token.exists():
token.unlink()
_barrier_if_sharded()
with flashy.utils.write_and_rename(checkpoint_path) as f:
torch.save(state, f)
_barrier_if_sharded()
if flashy.distrib.is_rank_zero():
token.touch()
_barrier_if_sharded()
_barrier_if_sharded()
if flashy.distrib.rank() == 0:
token.unlink()