# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass
from pathlib import Path, PosixPath, WindowsPath
from typing import Optional, Union
import lightning.fabric as fl
import lightning.pytorch as pl
from nemo.lightning import io
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
from nemo.lightning.pytorch.strategies.utils import RestoreConfig
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import uninject_model_parallel_rank
from nemo.utils.msc_utils import import_multistorageclient, is_multistorageclient_url
# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == "nt":
BasePath = WindowsPath
else:
BasePath = PosixPath
def _try_restore_tokenizer(model, ckpt_path):
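    """Attempt to restore the tokenizer saved with the checkpoint at ``ckpt_path`` onto ``model``.
    Returns the model unchanged if no tokenizer can be restored."""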
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.lightning.io import load_context
try:
tokenizer = load_context(ckpt_path, "model.tokenizer")
except ValueError as e:
logging.warning(
f"Encountered error while trying to restore tokenizer. Tokenizer is not restored. " f"Original error: {e}"
)
return model
if isinstance(tokenizer, TokenizerSpec):
model.tokenizer = tokenizer
model.__io__.tokenizer = tokenizer.__io__
else:
        # Ignore if the checkpoint doesn't have a tokenizer; type(tokenizer) == TrainerContext in that case.
        logging.warning("Checkpoint does not have a model.tokenizer field. Tokenizer is not restored.")
return model
@dataclass(kw_only=True)
class AutoResume:
"""Class that handles the logic for setting checkpoint paths and restoring from
checkpoints in NeMo.
Attributes:
restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model
weights, optimizer states, etc.
If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be
automatically converted to a NeMo compatible format.
            resume_from_directory or the run's log_dir takes precedence over restore_config.
resume_from_directory (str): Path to the checkpointing directory to restore from.
resume_from_path (str): Path to a specific checkpoint to restore from.
        resume_if_exists (bool): Whether this experiment is resuming from a previous run. If
            True, it sets trainer._checkpoint_connector._ckpt_path so that the trainer
            auto-resumes. exp_manager will move files under log_dir to log_dir/run_{int}.
            Defaults to False.
        resume_past_end (bool): By default, AutoResume throws an error if resume_if_exists is
            True and a checkpoint matching ``*end.ckpt`` is found, indicating that a previous
            training run fully completed. Setting resume_past_end=True disables this behavior
            and loads the last checkpoint instead.
resume_ignore_no_checkpoint (bool): AutoResume throws an error if resume_if_exists is
True and no checkpoint could be found. Setting resume_ignore_no_checkpoint=True
disables this behavior, in which case exp_manager will print a message and
continue without restoring.
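
    Example:
        A minimal usage sketch; ``trainer`` and ``model`` are assumed to be a configured
        lightning.pytorch.Trainer and a NeMo model created elsewhere::

            resume = AutoResume(
                resume_if_exists=True,
                resume_ignore_no_checkpoint=True,
            )
            resume.setup(trainer, model)
            # If a resumable checkpoint was found, trainer.ckpt_path now points at it.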
"""
restore_config: Optional[RestoreConfig] = None
resume_from_directory: Optional[str] = None
resume_from_path: Optional[str] = None
resume_if_exists: bool = False
resume_past_end: bool = False
resume_ignore_no_checkpoint: bool = False
WEIGHTS_PATH = "weights"
def get_weights_path(self, path) -> Path:
"""Returns the path to the weights directory within the specified path.
Args:
path: The checkpoint directory path
Returns:
Path: A Path object pointing to the weights directory
"""
return path / self.WEIGHTS_PATH
def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
"""Sets up checkpoint restoration for the Pytorch Lightning trainer.
This method configures the trainer with the appropriate checkpoint path for resuming
training and handles loading model artifacts like tokenizers when specified.
Args:
trainer: The PyTorch Lightning trainer or Fabric instance
model: Optional model instance to load artifacts into
Raises:
NotImplementedError: If trainer is a Fabric instance (not yet supported)
"""
if isinstance(trainer, fl.Fabric):
raise NotImplementedError("Fabric is not supported yet.")
trainer_ckpt_path = self.get_trainer_ckpt_path(model)
if trainer_ckpt_path:
trainer.ckpt_path = trainer_ckpt_path
trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
# Load artifacts
if getattr(self.restore_config, "load_artifacts", False):
if isinstance(trainer_ckpt_path, AdapterPath):
# load tokenizer from the base model during peft resume, in case the first peft checkpoint
# is deleted before the current peft checkpoint is saved
context_path = trainer_ckpt_path.base_model_path / "context"
if not context_path.exists():
context_path = trainer_ckpt_path.base_model_path
else:
context_path = self.get_context_path(model)
model = _try_restore_tokenizer(model, context_path)
elif self.restore_config:
new_path = self._extract_path(
path=self.restore_config.path,
)
assert not isinstance(new_path, AdapterPath), "AdapterPath is not supported for restore_config"
self.restore_config.path = str(new_path)
trainer.strategy.restore_config = self.restore_config
# Load artifacts
if self.restore_config.load_artifacts:
if isinstance(new_path, AdapterPath):
context_path = Path(new_path.base_model_path) / "context"
else:
context_path = new_path / "context"
if not context_path.is_dir():
context_path = new_path
_try_restore_tokenizer(model, context_path)
def _extract_path(self, path: str) -> BasePath:
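        """Resolve ``nemo://`` paths to their location under the local NeMo models cache;
        other paths are returned unchanged as ``Path`` objects."""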
if "://" in path:
            assert path.startswith("nemo://"), "Only NeMo-based paths starting with nemo:// are currently supported."
_, _path = path.split("://")
new_path = os.path.join(NEMO_MODELS_CACHE, _path)
else:
new_path = path
if isinstance(new_path, str):
new_path = Path(new_path)
return new_path
def _get_base_model_path_for_adapter(self, adapter_meta_path, model):
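        """Return the base model checkpoint directory recorded in the adapter metadata file at ``adapter_meta_path``."""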
with open(adapter_meta_path, "r") as f:
metadata = json.load(f)
# Use the model_ckpt_path from metadata directly
base_model_path = Path(metadata["model_ckpt_path"])
# If base_model_path points to a specific checkpoint file, use its parent directory
if not base_model_path.is_dir() and base_model_path.exists():
base_model_path = base_model_path.parent
return base_model_path
def _find_trainer_ckpt_path(self) -> Optional[Path]:
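        """Locate the checkpoint to auto-resume from in ``resume_from_directory`` or ``<log_dir>/checkpoints``, if any."""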
from nemo.utils.exp_manager import NotFoundError, _filter_out_unfinished_checkpoints
app_state = AppState()
log_dir = app_state.log_dir
checkpoint = None
        # Use <log_dir>/checkpoints/ unless resume_from_directory is set
if self.resume_from_directory:
if is_multistorageclient_url(self.resume_from_directory):
msc = import_multistorageclient()
checkpoint_dir = msc.Path(self.resume_from_directory)
else:
checkpoint_dir = Path(self.resume_from_directory)
elif log_dir is not None:
            checkpoint_dir = Path(log_dir) / "checkpoints"
        else:  # i.e., if log_dir is None
return None
# when using distributed checkpointing, checkpoint_dir is a directory of directories
# we check for this here
dist_checkpoints = [d for d in list(checkpoint_dir.glob("*")) if d.is_dir()]
end_dist_checkpoints = [d for d in dist_checkpoints if d.match("*end")]
last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")]
end_chkpt_cnt = len(end_dist_checkpoints)
end_checkpoints = _filter_out_unfinished_checkpoints(end_dist_checkpoints)
finished_end_chkpt_cnt = len(end_checkpoints)
if end_chkpt_cnt > 0 and finished_end_chkpt_cnt == 0:
raise ValueError(
"End checkpoint is unfinished and cannot be used to resume the training."
" Please remove the checkpoint manually to avoid unexpected cosequences, such as"
" restarting from scratch."
)
last_chkpt_cnt = len(last_dist_checkpoints)
last_checkpoints = _filter_out_unfinished_checkpoints(last_dist_checkpoints)
finished_last_chkpt_cnt = len(last_checkpoints)
if last_chkpt_cnt > 0 and finished_last_chkpt_cnt == 0:
raise ValueError(
"Last checkpoint is unfinished and cannot be used to resume the training."
" Please remove the checkpoint manually to avoid unexpected cosequences, such as"
" restarting from scratch. Hint: Iteration number can be added to the checkpoint name pattern"
" to maximize chance that there is at least one finished last checkpoint to resume from."
)
        if not checkpoint_dir.exists() or (len(end_checkpoints) == 0 and len(last_checkpoints) == 0):
if self.resume_ignore_no_checkpoint:
message = (
f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir "
f":{checkpoint_dir}. "
)
if not self.restore_config:
logging.warning(message + "Training from scratch.")
else:
logging.info(message + "Trying to resume from RestoreConfig.")
else:
if self.restore_config:
# resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore
# later instead.
return None
else:
raise NotFoundError(
f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir "
f":{checkpoint_dir}. Cannot resume."
)
elif len(end_checkpoints) > 0:
if not self.resume_past_end:
raise ValueError(
f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
)
if len(end_checkpoints) > 1:
if "mp_rank" in str(end_checkpoints[0]):
checkpoint = end_checkpoints[0]
else:
raise ValueError(f"Multiple checkpoints {end_checkpoints} that matches *end.ckpt.")
elif len(last_checkpoints) > 1:
            if any(s in str(last_checkpoints[0]) for s in ("mp_rank", "tp_rank", "fsdp_shard")):
checkpoint = last_checkpoints[0]
checkpoint = uninject_model_parallel_rank(checkpoint)
else:
# Select the checkpoint with the latest modified time
checkpoint = sorted(last_checkpoints, key=lambda pth: pth.lstat().st_mtime, reverse=True)[0]
logging.warning(
f"Multiple checkpoints {last_checkpoints} matches *last.ckpt. Selecting one with the latest "
f"modified time."
)
else:
checkpoint = last_checkpoints[0]
return checkpoint
def get_context_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
"""Retrieves the path to the context directory of a checkpoint.
        The context directory contains serialized objects like tokenizers. This method
        handles both cases: the context stored directly in the checkpoint directory, or
        in a subdirectory called "context".
Args:
model: Optional model instance
Returns:
Optional[Path]: Path to the context directory if found, None otherwise
"""
checkpoint = None
app_state = AppState()
app_state.restore = self.resume_if_exists
if self.resume_if_exists:
checkpoint = self._find_trainer_ckpt_path()
if checkpoint:
maybe_context_path = checkpoint / "context"
if maybe_context_path.is_dir():
checkpoint = maybe_context_path
return checkpoint
def get_trainer_ckpt_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
"""Resolves the path to a checkpoint for resuming training.
This method handles various checkpoint sources with the following priority:
1. Explicit path specified in resume_from_path
2. Automatic discovery in the checkpoint directory when resume_if_exists=True
For adapter checkpoints (PEFT), it also retrieves the base model path from metadata.
Args:
model: Optional model instance
Returns:
Optional[Path]: Path to the checkpoint if found, or AdapterPath for PEFT checkpoints,
or None if no checkpoint is found or needed
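
        Example:
            A sketch of typical usage; the run layout and configuration are illustrative::

                resume = AutoResume(resume_if_exists=True, resume_ignore_no_checkpoint=True)
                ckpt = resume.get_trainer_ckpt_path()
                # ckpt is an AdapterPath for PEFT checkpoints, a path to the checkpoint
                # weights for regular checkpoints, or None when there is nothing to resume.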
"""
if self.resume_from_path:
if is_multistorageclient_url(self.resume_from_path):
msc = import_multistorageclient()
resume_from_path = msc.Path(self.resume_from_path)
else:
resume_from_path = Path(self.resume_from_path)
maybe_weights_path = self.get_weights_path(resume_from_path)
if maybe_weights_path.is_dir():
adapter_meta_path = maybe_weights_path / ADAPTER_META_FILENAME
if adapter_meta_path.exists():
# the resume_from_path is an adapter checkpoint
base_model_path = self._get_base_model_path_for_adapter(adapter_meta_path, model)
return AdapterPath(Path(self.resume_from_path), base_model_path=base_model_path)
else:
                    # resume_from_path is not a PEFT checkpoint
return maybe_weights_path
else:
return self.resume_from_path
checkpoint = None
app_state = AppState()
app_state.restore = self.resume_if_exists
if self.resume_if_exists:
checkpoint = self._find_trainer_ckpt_path()
if checkpoint:
maybe_weights_path = self.get_weights_path(checkpoint)
if maybe_weights_path.is_dir():
checkpoint = maybe_weights_path
if checkpoint:
adapter_meta_path = checkpoint / ADAPTER_META_FILENAME
if adapter_meta_path.exists():
base_model_path = self._get_base_model_path_for_adapter(adapter_meta_path, model)
return AdapterPath(checkpoint, base_model_path=base_model_path)
else:
return checkpoint
return None
class AdapterPath(BasePath):
"""Path object for adapter paths which include a field for the base model the adapters are trained on
to facilitate model loading."""
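    # Example (illustrative paths only): AdapterPath("/ckpts/peft/last", base_model_path=Path("/ckpts/base"))
    # behaves like a regular Path to the adapter checkpoint while also remembering the base model location.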
base_model_path: Optional[Path]
def __new__(cls, *args, base_model_path: Optional[Path] = None, **kwargs):
output = super().__new__(cls, *args, **kwargs)
output.base_model_path = base_model_path
return output
def __repr__(self):
return "{}({!r}, base_model_path={})".format(self.__class__.__name__, self.as_posix(), self.base_model_path)