Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,129 Bytes
9d0d223 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
import logging
from pathlib import Path
import re
import typing as tp
import flashy
import torch
from ..environment import AudioCraftEnvironment
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
class CheckpointSource(Enum):
    """Closed set of identifiers describing where a checkpoint comes from.

    Each member's value is the lowercase string form of its name, so a member
    can be recovered from its string with `CheckpointSource(value)`.
    """

    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Build the checkpoint file name used across the AudioCraft codebase.

    The returned name has the format `checkpoint_<name>.th(.<rank>)`. By convention,
    name is expected to be empty for the last checkpoint, 'best' for the best
    checkpoint, or the epoch number.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem. Both None and
            the empty string produce no suffix, i.e. `checkpoint.th`.
        rank (int, optional): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    # An empty name designates the last checkpoint just like None does; without this
    # check an empty string would yield the malformed stem `checkpoint_`.
    name_part = '' if name is None or name == '' else f'_{name}'
    if rank is None:
        rank = flashy.distrib.rank()
    # Only ranks above zero get a rank suffix, and only when FSDP is used.
    suffix = f'.{rank}' if use_fsdp and rank > 0 else ''
    return f'checkpoint{name_part}.th{suffix}'
def is_sharded_checkpoint(path: Path) -> bool:
    """Tell whether the file name of `path` ends with `.th.<digits>`, i.e. a per-rank shard.

    Args:
        path (Path): Checkpoint path; only its final component is inspected.
    Returns:
        bool: True when the name carries a numeric rank suffix after `.th`.
    """
    rank_suffix = re.search(r'\.th\.\d+$', path.name)
    return rank_suffix is not None
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a given checkpoint path for a provided dora sig or path.

    A value starting with `//sig/` is treated as a dora signature and mapped to
    `<dora dir>/xps/<sig>`. Any other value is treated as a path and passed through
    `AudioCraftEnvironment.resolve_reference_path`. When the result is a directory,
    the checkpoint file name built by `checkpoint_name` is appended to it.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    sig_prefix = '//sig/'
    sig_or_path = str(sig_or_path)
    if sig_or_path.startswith(sig_prefix):
        # Imported here rather than at module level (presumably to avoid a circular
        # import), and only on this branch: the dora directory is needed solely to
        # locate signatures, so plain paths do not pay for importing the train module.
        from audiocraft import train
        xps_root = train.main.dora.dir / 'xps'
        path = xps_root / sig_or_path[len(sig_prefix):]
    else:
        path = AudioCraftEnvironment.resolve_reference_path(Path(sig_or_path))

    if path.is_dir():
        path = path / checkpoint_name(name, use_fsdp=use_fsdp)

    return path if path.exists() else None
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load the state stored at the given checkpoint path, mapped onto CPU.

    Args:
        checkpoint_path (Path): Path of the checkpoint file to read.
        is_sharded (bool): Whether the checkpoint is sharded across ranks. If so, and the
            non-FSDP checkpoint file exists in the same folder, the shard is first
            validated with `check_sharded_checkpoint`.
    Returns:
        Any: The object returned by `torch.load`.
    """
    if is_sharded:
        rank_zero_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if rank_zero_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank_zero_path)
    loaded = torch.load(checkpoint_path, map_location='cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return loaded
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Write `state` to `checkpoint_path` and log the destination.

    Args:
        state (Any): Object to serialize.
        checkpoint_path (Path): Destination file.
        is_sharded (bool): Whether the checkpoint is sharded across ranks; forwarded
            to `_safe_save_checkpoint`.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded=is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Delete the oldest epoch checkpoints so that at most `keep_last` remain.

    Only files next to `checkpoint_path` named `checkpoint_<digits>.th` (with a
    `.<rank>` suffix on ranks above zero) are considered; other names such as
    `checkpoint_best.th` are ignored.

    Args:
        checkpoint_path (Path): Any checkpoint path inside the folder to clean.
        keep_last (int, optional): Number of most recent epoch checkpoints to keep.
            Nothing is removed when None or not strictly positive.
    """
    if keep_last is None or keep_last <= 0:
        return

    rank = flashy.distrib.rank()
    rank_suffix = f'.{rank}' if rank > 0 else ''
    pattern = f'checkpoint_*.th{rank_suffix}'

    epoch_and_path = []
    for candidate in Path(checkpoint_path.parent).glob(pattern):
        # File name layout is `checkpoint_<epoch>.th...`: keep what precedes the
        # first dot, then what follows the first underscore.
        stem = candidate.name.split('.', 1)[0]
        epoch_text = stem.split('_', 1)[1]
        if epoch_text.isdigit():
            epoch_and_path.append((int(epoch_text), candidate))

    # Stable sort on the epoch only, oldest first.
    epoch_and_path.sort(key=lambda item: item[0])
    n_stale = max(0, len(epoch_and_path) - keep_last)
    for _, stale in epoch_and_path[:n_stale]:
        logger.debug("Removing checkpoint: %s", str(stale))
        stale.unlink(missing_ok=True)
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    Args:
        checkpoint_path (Path): Path of the shard handled by the current process.
        rank0_checkpoint_path (Path): Path of the rank zero checkpoint, next to which
            the `.tmp.done` token file is looked up.
    Raises:
        RuntimeError: If a `<checkpoint_path>.old` file is present.
    """
    legacy_path = Path(str(checkpoint_path) + '.old')
    if legacy_path.exists():
        raise RuntimeError(
            f"Old checkpoint {legacy_path} from previous version of this code exist, cannot safely proceed.")

    done_token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    pending_path = Path(str(checkpoint_path) + '.tmp')
    # Finish the work of a previous run that got interrupted while dumping:
    # the token is present but this shard's temporary file was never renamed.
    if done_token.exists() and pending_path.exists():
        pending_path.rename(checkpoint_path)

    # Wait for every process before the token goes away, as each one reads it above.
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and done_token.exists():
        done_token.unlink()
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save a checkpoint safely, including when it is sharded across processes.

    Rank zero maintains a `<checkpoint_path>.tmp.done` token file: it is created once
    every process has written its data and removed once the save is complete.

    Args:
        state (Any): Object serialized with `torch.save`.
        checkpoint_path (Path): Destination file for the current process.
        is_sharded (bool): Whether the checkpoint is sharded; barriers are only
            issued in that case.
    """
    def _sync_shards():
        if is_sharded:
            flashy.distrib.barrier()

    # Only rank zero ever touches the token file; the path itself is cheap to build.
    done_token = Path(str(checkpoint_path) + '.tmp.done')

    if flashy.distrib.is_rank_zero():
        # Drop any token left behind by an earlier save.
        if done_token.exists():
            done_token.unlink()
    _sync_shards()
    # NOTE(review): write_and_rename presumably writes to a temporary file and renames
    # it when the `with` block exits - confirm against flashy.
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
        _sync_shards()
        # Every process has finished writing at this point: flag it with the token.
        if flashy.distrib.is_rank_zero():
            done_token.touch()
        _sync_shards()
    _sync_shards()
    if flashy.distrib.rank() == 0:
        done_token.unlink()
|