File size: 6,129 Bytes
5325fcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
import logging
from pathlib import Path
import re
import typing as tp

import flashy
import torch

from ..environment import AudioCraftEnvironment


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class CheckpointSource(Enum):
    """Enumeration of checkpoint sources, each member carrying a lowercase string value.

    NOTE(review): the per-member meanings noted below are inferred from the member
    names only; confirm against the code that consumes this enum.
    """
    CURRENT_XP = "current_xp"  # presumably a checkpoint belonging to the current experiment (XP)
    PRETRAINED = "pretrained"  # presumably a checkpoint of a pretrained model
    OTHER = "other"  # presumably any other checkpoint origin


def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Build the canonical checkpoint file name used across the AudioCraft codebase.

    The result follows the pattern `checkpoint[_<name>].th[.<rank>]`. By convention, the name is
    left unset for the last checkpoint, set to 'best' for the best checkpoint, or set to the
    epoch number.

    Args:
        name (str, optional): Suffix appended to the `checkpoint` file stem, omitted when None.
        rank (int, optional): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP. The `.<rank>` extension is
            only appended when this is set and the rank is strictly positive.
    Returns:
        str: The checkpoint name.
    """
    effective_rank = flashy.distrib.rank() if rank is None else rank
    stem = 'checkpoint' if name is None else f'checkpoint_{name}'
    # Only non-zero ranks of an FSDP run get their own, rank-suffixed file name.
    rank_ext = f'.{effective_rank}' if (effective_rank > 0 and use_fsdp) else ''
    return f'{stem}.th{rank_ext}'


def is_sharded_checkpoint(path: Path) -> bool:
    """Tell whether the file name of `path` ends with `.th.<digits>`, i.e. names a per-rank shard."""
    shard_match = re.search(r'\.th\.\d+$', path.name)
    return shard_match is not None


def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Turn a dora signature or a path into the path of an existing checkpoint.

    Inputs starting with `//sig/` are treated as a dora signature and mapped to the folder of
    that signature under the dora `xps` directory. Any other input is treated as a path and
    passed through `AudioCraftEnvironment.resolve_reference_path`. When the outcome is a
    directory, the file name built by `checkpoint_name(name, use_fsdp=use_fsdp)` is appended.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature (prefixed with `//sig/`).
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path if it exists on disk, None otherwise.
    """
    # Imported inside the function rather than at module level, as in the original code.
    from audiocraft import train
    xps_root = train.main.dora.dir / 'xps'
    sig_prefix = '//sig/'
    reference = str(sig_or_path)
    if reference.startswith(sig_prefix):
        candidate = xps_root / reference[len(sig_prefix):]
    else:
        candidate = AudioCraftEnvironment.resolve_reference_path(Path(reference))

    if candidate.is_dir():
        # A folder was given: point to the checkpoint file expected inside it.
        candidate = candidate / checkpoint_name(name, use_fsdp=use_fsdp)

    return candidate if candidate.exists() else None


def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Read the state stored at `checkpoint_path`, with `map_location` set to 'cpu'.

    Args:
        checkpoint_path (Path): Checkpoint file to load.
        is_sharded (bool): Whether the checkpoint is sharded across ranks. When set, and when
            the rank-0 checkpoint file exists next to `checkpoint_path`, the shards are first
            passed through `check_sharded_checkpoint`.
    Returns:
        Any: The object returned by `torch.load`.
    """
    if is_sharded:
        # NOTE(review): the rank-0 file is always the unnamed `checkpoint.th`, whatever name
        # suffix `checkpoint_path` itself carries — confirm this is intended.
        first_rank_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if first_rank_path.exists():
            check_sharded_checkpoint(checkpoint_path, first_rank_path)
    # NOTE: `torch.load` relies on pickle; only load checkpoints coming from a trusted source.
    loaded = torch.load(checkpoint_path, map_location='cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return loaded


def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Write `state` to `checkpoint_path` through `_safe_save_checkpoint` and log the destination.

    Args:
        state (Any): Object to serialize.
        checkpoint_path (Path): Destination file.
        is_sharded (bool): Whether the checkpoint is sharded across ranks.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded=is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)


def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Delete the oldest epoch checkpoints so that at most `keep_last` of them remain.

    Only files named `checkpoint_<epoch>.th` (followed by `.<rank>` on ranks above 0) located in
    the folder of `checkpoint_path` are considered, where `<epoch>` consists of digits only.
    Files whose name part is not numeric are never removed.

    Args:
        checkpoint_path (Path): Path of a checkpoint; only its parent folder is used.
        keep_last (int, optional): Number of most recent epoch checkpoints to keep. Nothing is
            removed when None or not strictly positive.
    """
    if keep_last is None or keep_last <= 0:
        return

    rank = flashy.distrib.rank()
    rank_ext = f'.{rank}' if rank > 0 else ''

    # Collect (path, epoch) pairs for the checkpoints whose name part is an epoch number.
    epoch_checkpoints = []
    for candidate in Path(checkpoint_path.parent).glob(f'checkpoint_*.th{rank_ext}'):
        stem = candidate.name.partition('.')[0]
        epoch_text = stem.partition('_')[2]
        if epoch_text.isdigit():
            epoch_checkpoints.append((candidate, int(epoch_text)))
    epoch_checkpoints.sort(key=lambda item: item[1])

    excess = len(epoch_checkpoints) - keep_last
    if excess <= 0:
        return
    # The list is sorted by increasing epoch, so the first `excess` entries are the oldest.
    for stale_path, _ in epoch_checkpoints[:excess]:
        logger.debug("Removing checkpoint: %s", str(stale_path))
        stale_path.unlink(missing_ok=True)


def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    If the `<rank0_checkpoint_path>.tmp.done` token is present together with a
    `<checkpoint_path>.tmp` file, the temporary file is renamed to `checkpoint_path`.
    All ranks then meet at a barrier, after which rank zero removes the token.

    Args:
        checkpoint_path (Path): Checkpoint file of the calling rank.
        rank0_checkpoint_path (Path): Checkpoint file of rank zero, next to which the
            token file is looked up.
    Raises:
        RuntimeError: If a `<checkpoint_path>.old` file exists.
    """
    # Refuse to continue if a `.old` file is found: it is reported as coming from
    # a previous version of this code.
    old_path = Path(str(checkpoint_path) + '.old')
    if old_path.exists():
        raise RuntimeError(
            f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
    # Finish the work of a previous run that got interrupted while dumping.
    # The token is the file touched by rank zero in `_safe_save_checkpoint` once the
    # post-save barrier has been passed.
    token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    tmp_path = Path(str(checkpoint_path) + '.tmp')
    if token.exists():
        if tmp_path.exists():
            tmp_path.rename(checkpoint_path)
    # Wait for every rank before the token is removed, so that all of them had a
    # chance to see it.
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and token.exists():
        token.unlink()


def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save checkpoints in a safe manner, even when checkpoints are sharded across nodes.

    Rank zero maintains a `<checkpoint_path>.tmp.done` token file: it is cleared before
    saving, touched once the state has been written, and removed again at the end. When
    `is_sharded` is set, barriers are placed between these steps.

    Args:
        state (Any): Object to serialize with `torch.save`.
        checkpoint_path (Path): Destination file for the calling rank.
        is_sharded (bool): Whether the checkpoint is sharded across ranks; enables the barriers.
    """
    def _barrier_if_sharded():
        # Synchronize ranks only for sharded checkpoints; no-op otherwise.
        if is_sharded:
            flashy.distrib.barrier()

    # `token` is only bound on rank zero; every later use is guarded by a rank-zero check.
    if flashy.distrib.is_rank_zero():
        token = Path(str(checkpoint_path) + '.tmp.done')
        # Clear any token left on disk before writing anything new.
        if token.exists():
            token.unlink()
    _barrier_if_sharded()
    # NOTE(review): `write_and_rename` presumably writes to `<checkpoint_path>.tmp` and renames
    # it to `checkpoint_path` when the block exits (see the `.tmp` handling in
    # `check_sharded_checkpoint`) — confirm against flashy.
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
        # The token is touched only after this barrier, i.e. after every rank finished
        # its `torch.save` call when sharded.
        _barrier_if_sharded()
        if flashy.distrib.is_rank_zero():
            token.touch()
        # Make sure the token exists before any rank leaves the `with` block.
        _barrier_if_sharded()
    # The token is removed only after this barrier, i.e. after every rank left the `with` block.
    _barrier_if_sharded()
    if flashy.distrib.rank() == 0:
        token.unlink()