openpi-oft / src /openpi /training /checkpoints.py
Sichang0621's picture
Upload folder using huggingface_hub
ce5618e verified
from __future__ import annotations
import asyncio
import concurrent.futures as futures
import dataclasses
import logging
from typing import Protocol
from etils import epath
import jax
import orbax.checkpoint as ocp
import orbax.checkpoint.future as future
from openpi.shared import array_typing as at
import openpi.shared.normalize as _normalize
import openpi.training.data_loader as _data_loader
import openpi.training.utils as training_utils
def initialize_checkpoint_dir(
checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
checkpoint_dir = epath.Path(checkpoint_dir).resolve()
resuming = False
if checkpoint_dir.exists():
if overwrite:
checkpoint_dir.rmtree()
checkpoint_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
elif resume:
resuming = True
else:
raise FileExistsError(
f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
"to indicate how to handle it."
)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
mngr = ocp.CheckpointManager(
checkpoint_dir,
item_handlers={
"assets": CallbackHandler(),
"train_state": ocp.PyTreeCheckpointHandler(),
"params": ocp.PyTreeCheckpointHandler(),
},
options=ocp.CheckpointManagerOptions(
max_to_keep=1,
keep_period=keep_period,
create=False,
async_options=ocp.AsyncOptions(timeout_secs=7200),
),
)
# Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
# not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
# checkpoint, since it will fail.
if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
resuming = False
return mngr, resuming
def save_state(
checkpoint_manager: ocp.CheckpointManager,
state: training_utils.TrainState,
data_loader: _data_loader.DataLoader,
step: int,
):
def save_assets(directory: epath.Path):
# Save the normalization stats.
data_config = data_loader.data_config()
norm_stats = data_config.norm_stats
if norm_stats is not None and data_config.asset_id is not None:
_normalize.save(directory / data_config.asset_id, norm_stats)
# Split params that can be used for inference into a separate item.
with at.disable_typechecking():
train_state, params = _split_params(state)
items = {
"assets": save_assets,
"train_state": train_state,
"params": {"params": params},
}
checkpoint_manager.save(step, items)
def restore_state(
checkpoint_manager: ocp.CheckpointManager,
state: training_utils.TrainState,
data_loader: _data_loader.DataLoader,
step: int | None = None,
) -> training_utils.TrainState:
del data_loader
with at.disable_typechecking():
# Split params that can be used for inference into a separate item.
train_state, params = _split_params(state)
restored = checkpoint_manager.restore(
step,
items={
"train_state": train_state,
"params": {"params": params},
},
)
return _merge_params(restored["train_state"], restored["params"])
def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
norm_stats_dir = epath.Path(assets_dir) / asset_id
norm_stats = _normalize.load(norm_stats_dir)
logging.info(f"Loaded norm stats from {norm_stats_dir}")
return norm_stats
class Callback(Protocol):
def __call__(self, directory: epath.Path) -> None: ...
class CallbackHandler(ocp.AsyncCheckpointHandler):
"""A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""
def save(self, directory: epath.Path, args: CallbackSave):
if jax.process_index() == 0:
args.callback(directory)
async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]
def restore(self, *args, **kwargs):
raise NotImplementedError("CallbackHandler does not support restore")
@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
callback: Callback
@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs): ...
def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
if state.ema_params is not None:
params = state.ema_params
train_state = dataclasses.replace(state, ema_params=None)
else:
params = state.params
train_state = dataclasses.replace(state, params={})
return train_state, params
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
# Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split.
if train_state.params:
return dataclasses.replace(train_state, ema_params=params["params"])
return dataclasses.replace(train_state, params=params["params"])