subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import signal
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional
import nemo_run as run
import yaml
from lightning.pytorch import Callback
from lightning.pytorch.loggers import WandbLogger
from nemo_run.core.serialization.yaml import YamlSerializer
from nemo.lightning.pytorch.callbacks import MemoryProfileCallback, NsysCallback, PreemptionCallback
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy
from nemo.utils import logging
from nemo.utils.import_utils import safe_import
res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency')
# This file contains plugins based on NeMo-Run's run.Plugin API.
# Plugins operate both on a configured task and an executor at the same time, and are specific to NeMo-Run.
# If you are adding functionality that goes directly into the PyTorch Lightning trainer,
# consider adding a callback instead of a plugin.
def _merge_callbacks(partial: run.Partial, callbacks: list[run.Config[Callback]]):
    """Merge ``callbacks`` into ``partial.trainer.callbacks`` in place.

    Nothing happens when ``partial`` has no ``trainer`` attribute. When the trainer
    already holds a non-empty callback list, every callback that is not yet present
    (compared with ``in``) is appended to it. Otherwise the trainer's callbacks are
    set to a copy of ``callbacks``.

    The callbacks stored on the task are always deep copies, so later edits to the
    task's configuration never leak back into the caller's callback objects.

    Args:
        partial: The configured task whose trainer callbacks should be extended.
        callbacks: Callback configurations to merge into the task.
    """
    if not hasattr(partial, "trainer"):
        return

    trainer = partial.trainer
    existing = getattr(trainer, "callbacks", None)
    if existing:
        for callback in callbacks:
            if callback not in existing:
                # Copy on append as well, mirroring the assignment branch below.
                existing.append(copy.deepcopy(callback))
    else:
        trainer.callbacks = copy.deepcopy(callbacks)
@dataclass(kw_only=True)
class PreemptionPlugin(run.Plugin):
    """
    Plugin that wires preemption handling into a task and its executor.

    Args:
        preempt_time (int): Number of seconds before the job's time limit at which the
            executor sends the preemption signal, giving the task a chance to stop
            gracefully. Defaults to 60. Only applied for ``run.SlurmExecutor``.
        sig (signal.Signals): The signal to send and listen for. Defaults to ``signal.SIGTERM``.
        callbacks (list[run.Config[Callback]]): Callback configurations merged into the
            task's existing callbacks. When left unset (or empty), a single
            ``PreemptionCallback`` configured with ``sig`` is used.
    """

    preempt_time: int = 60
    sig: signal.Signals = signal.SIGTERM
    callbacks: list[run.Config[Callback]] = None

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Configure the executor's preemption signal and add the preemption callbacks.

        ``run.Script`` tasks are left untouched; only a warning is logged.
        """
        if isinstance(task, run.Script):
            logging.warning(
                f"The {self.__class__.__name__} will have no effect on the task as it's an instance of run.Script"
            )
            return

        if isinstance(executor, run.SlurmExecutor):
            # The executor's signal is given as "<signal number>@<seconds before time limit>".
            logging.info(
                f"{self.__class__.__name__} will send a {self.sig.name} {self.preempt_time} seconds before the job's time limit for your Slurm executor."  # pylint: disable=C0301
            )
            executor.signal = f"{self.sig.value}@{self.preempt_time}"

        preemption_callbacks = self.callbacks
        if not preemption_callbacks:
            preemption_callbacks = [run.Config(PreemptionCallback, sig=self.sig)]
        _merge_callbacks(task, callbacks=preemption_callbacks)
@dataclass(kw_only=True)
class FaultTolerancePlugin(run.Plugin):
    """
    A plugin for setting up the fault tolerance callback from nvidia-resiliency-ext.

    This plugin enables workload hang detection, automatic calculation of timeouts used for hang detection,
    detection of rank(s) terminated due to an error and workload respawning in case of a failure.

    Note: FaultTolerancePlugin does not work with the NsysPlugin.

    Args:
        num_in_job_restarts (int): Max number of restarts on failure, within the same job. Default is 3.
        num_job_retries_on_failure (int): Max number of new job restarts on failure. Default is 2.
        initial_rank_heartbeat_timeout (int): Timeouts are time intervals used by a rank monitor to detect
            that a rank is not alive. This is the max timeout for the initial heartbeat. Default is 1800.
        rank_heartbeat_timeout (int): This is the timeout for subsequent heartbeats after the initial heartbeat.
            Default is 300.
    """

    num_in_job_restarts: int = 3
    num_job_retries_on_failure: int = 2
    initial_rank_heartbeat_timeout: int = 1800
    rank_heartbeat_timeout: int = 300

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Set up the fault tolerance plugin.

        Replaces the executor's launcher with ``run.FaultTolerance``, sets the executor's
        retry count and merges a ``FaultToleranceCallback`` into the task's callbacks.

        Raises:
            AssertionError: If nvidia-resiliency-ext is unavailable, the task is not a
                ``run.Partial``, or Nsys profiling is configured on the launcher or task.
        """
        assert HAVE_RES, "nvidia-resiliency-ext.ptl_resiliency is required to use the FaultTolerancePlugin."
        # Validate the task before touching the executor so a rejected task leaves it unmodified.
        assert isinstance(task, run.Partial)

        executor.launcher = run.FaultTolerance(
            max_restarts=self.num_in_job_restarts,
            initial_rank_heartbeat_timeout=self.initial_rank_heartbeat_timeout,
            rank_heartbeat_timeout=self.rank_heartbeat_timeout,
        )
        executor.retries = self.num_job_retries_on_failure

        callbacks = [
            run.Config(
                res_module.FaultToleranceCallback, autoresume=True, calculate_timeouts=True, exp_dir=task.log.log_dir
            )
        ]

        assert not executor.launcher.nsys_profile, "Nsys not supported with the FaultTolerancePlugin."

        # The trainer's callbacks may be missing or None; treat both as "no callbacks"
        # instead of failing with a TypeError while iterating.
        existing_callbacks = getattr(getattr(task, "trainer", None), "callbacks", None) or []
        assert not any(
            getattr(cb, "__fn_or_cls__", None) == NsysCallback for cb in existing_callbacks
        ), "Nsys not supported with FaultTolerancePlugin."

        _merge_callbacks(task, callbacks=callbacks)
@dataclass(kw_only=True)
class NsysPlugin(run.Plugin):
    """
    Plugin that enables nsys profiling for a run.

    It lets you choose the step window to profile, the ranks to profile on,
    and which events nsys should trace.

    Args:
        start_step (int): The step at which to start the nsys profiling.
        end_step (int): The step at which to end the nsys profiling.
        ranks (Optional[list[int]]): The ranks on which to run the nsys profiling. If not specified,
            profiling will be run on rank 0.
        nsys_trace (Optional[list[str]]): The events to trace during profiling. If not specified,
            'nvtx' and 'cuda' events will be traced.
        gen_shape (bool): Forwarded unchanged to ``NsysCallback`` as ``gen_shape``. Defaults to False.
    """

    start_step: int
    end_step: int
    ranks: Optional[list[int]] = None
    nsys_trace: Optional[list[str]] = None
    gen_shape: bool = False

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Add the ``NsysCallback`` to partial tasks and turn on nsys in the executor's launcher."""
        if isinstance(task, run.Partial):
            profiled_ranks = self.ranks or [0]
            _merge_callbacks(
                task,
                callbacks=[
                    run.Config(
                        NsysCallback,
                        start_step=self.start_step,
                        end_step=self.end_step,
                        ranks=profiled_ranks,
                        gen_shape=self.gen_shape,
                    )
                ],
            )

        nsys_launcher = executor.get_launcher()
        nsys_launcher.nsys_profile = True
        nsys_launcher.nsys_trace = self.nsys_trace or ["nvtx", "cuda"]

        if not isinstance(executor, run.SlurmExecutor):
            return
        # NOTE: DO NOT change to f-string, `%q{}` is Slurm placeholder
        nsys_launcher.nsys_filename = "profile_%p_%q{SLURM_JOB_ID}_node%q{SLURM_NODEID}_rank%q{SLURM_PROCID}"
@dataclass(kw_only=True)
class MemoryProfilePlugin(run.Plugin):
    """
    Plugin that enables memory profiling for a run.

    It adds a ``MemoryProfileCallback`` to the task so that a timeline of memory
    allocations can be profiled during your run. You can choose which ranks are profiled.

    Args:
        dir (str): Directory to store the memory profile dump .pickle files
        ranks (Optional[list[int]]): The ranks on which to run the memory profiling. If not specified,
            profiling will be run on rank 0.
    """

    dir: str
    ranks: Optional[list[int]] = None

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Merge a ``MemoryProfileCallback`` into the task; non-partial tasks are left unchanged."""
        if not isinstance(task, run.Partial):
            return

        profiled_ranks = self.ranks or [0]
        _merge_callbacks(
            task,
            callbacks=[run.Config(MemoryProfileCallback, dir=self.dir, ranks=profiled_ranks)],
        )
@dataclass(kw_only=True)
class WandbPlugin(run.Plugin):
    """
    A plugin for setting up Weights & Biases.

    This plugin sets a ``WandbLogger`` to ``NeMoLogger``'s ``wandb`` arg,
    which in turn initializes the PyTorch Lightning `WandbLogger
    <https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.WandbLogger.html>`_.

    The plugin only takes effect when the ``WANDB_API_KEY`` environment variable is set;
    the key is then also copied into the executor's environment variables.
    Follow https://docs.wandb.ai/quickstart to retrieve your ``WANDB_API_KEY``.

    If `log_task_config` is True, the serialized task configuration, together with some
    experiment metadata, is attached to the Weights and Biases logger as its config.

    Args:
        name (str): The name for the Weights & Biases run.
        logger_fn (Callable[..., run.Config[WandbLogger]]): A callable that returns a Config of ``WandbLogger``
        log_task_config (bool, optional): Whether to log the task configuration to the logger.
            Defaults to True.

    Note:
        Nothing is raised for unsupported setups. A warning is logged when the task is a
        `run.Script` or when ``WANDB_API_KEY`` is not set, and the plugin has no effect.
    """

    name: str
    logger_fn: Callable[..., run.Config[WandbLogger]]
    log_task_config: bool = True

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Forward the API key to the executor and attach the wandb logger to the task."""
        if isinstance(task, run.Script):
            logging.warning(
                f"The {self.__class__.__name__} will have no effect on the task as it's an instance of run.Script"
            )
            return

        if "WANDB_API_KEY" not in os.environ:
            logging.warning(
                f"The {self.__class__.__name__} will have no effect as WANDB_API_KEY environment variable is not set."
            )
            return

        executor.env_vars["WANDB_API_KEY"] = os.environ["WANDB_API_KEY"]

        # Only tasks exposing ``log.wandb`` can receive the logger.
        if not (hasattr(task, "log") and hasattr(task.log, "wandb")):
            return
        task.log.wandb = self.logger_fn(name=self.name)

        if not self.log_task_config:
            return

        # Round-trip through YAML to obtain a plain dictionary describing the task.
        task_config = yaml.safe_load(YamlSerializer().serialize(task))
        experiment_meta = {
            "id": self.experiment_id,
            "task_name": self.name,
            "executor": executor.info(),
        }
        if isinstance(executor, run.SlurmExecutor):
            experiment_meta["remote_directory"] = os.path.join(executor.tunnel.job_dir, Path(executor.job_dir).name)
        else:
            experiment_meta["remote_directory"] = None
        experiment_meta["local_directory"] = executor.job_dir

        task_config["experiment"] = experiment_meta
        task.log.wandb.config = task_config
@dataclass(kw_only=True)
class ConfigValidationPlugin(run.Plugin):
    """
    A plugin for validating a NeMo task with its executor.

    This plugin is used to ensure that the NeMo environment, task, and executor meet certain criteria.
    The validation checks include preemption, checkpoint directory,
    serialization, and Weights and Biases (wandb) integration.

    Attributes:
        validate_preemption (bool): Whether to validate the preemption callback. If set to True, the plugin will
            assert that the task has a `PreemptionCallback`. Defaults to True.
        validate_checkpoint_dir (bool): Whether to validate the checkpoint directory. If set to True and the executor
            is a `SlurmExecutor`, the plugin will assert that the task's log directory exists in the mounts
            specified in the `SlurmExecutor`. Defaults to True.
        validate_serialization (bool): Whether to validate task serialization. If set to True, the plugin will
            assert that the task can be successfully serialized and deserialized using NeMo-Run's
            `ZlibJSONSerializer`. Defaults to True.
        validate_wandb (bool): Whether to validate Weights and Biases integration. If set to True, the plugin will
            assert that the executor's environment variables contain a `WANDB_API_KEY`
            and that NeMo Logger's `wandb` is set. Defaults to False.
        validate_nodes_and_devices (bool): Whether to validate the number of devices and nodes. If set to True
            and the executor is a `SlurmExecutor`, the plugin will assert that the task's trainer is configured
            to use the same number of nodes and devices as the executor. Defaults to True.
    """

    validate_preemption: bool = True
    validate_checkpoint_dir: bool = True
    validate_serialization: bool = True
    validate_wandb: bool = False
    validate_nodes_and_devices: bool = True

    @staticmethod
    def _container_path(mount: str) -> str:
        """Return the in-container path of a ``SRC[:DST[:FLAGS]]`` mount specification."""
        parts = mount.split(":")
        # The destination is the second field; a bare path is mounted at the same location.
        # Taking the last field would return the FLAGS (e.g. "ro") for three-part mounts.
        return parts[1] if len(parts) > 1 else parts[0]

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Run every enabled validation against the task and executor.

        Raises:
            AssertionError: If the task is not a ``run.Partial`` or any enabled validation fails.
        """
        assert isinstance(task, run.Partial)
        logging.info(f"Validating {task.__fn_or_cls__.__qualname__} and {executor.__class__.__qualname__}.")

        if self.validate_preemption:
            logging.info("Validating preemption callback")
            # Missing/None callbacks should fail the validation, not raise a TypeError.
            trainer_callbacks = task.trainer.callbacks or []
            assert any(
                getattr(callback, "__fn_or_cls__", None) == PreemptionCallback for callback in trainer_callbacks
            ), "The task's trainer does not have a PreemptionCallback."

        if self.validate_checkpoint_dir and isinstance(executor, run.SlurmExecutor):
            mounts = [self._container_path(m) for m in executor.container_mounts + ["/nemo_run"]]
            logging.info(f"Validating checkpoint dir {task.log.log_dir} exists in {mounts}")
            assert task.log.log_dir, "The task's log directory is not set."
            log_dir = Path(task.log.log_dir)
            # A log dir that is itself a mount point is just as valid as one nested below it.
            assert any(
                Path(mount) == log_dir or Path(mount) in log_dir.parents for mount in mounts
            ), f"Checkpoint dir {task.log.log_dir} is not inside any of the container mounts {mounts}."

        if self.validate_serialization:
            from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer

            logging.info("Validating serialization/de-serialization of task")
            serializer = ZlibJSONSerializer()
            assert (
                serializer.deserialize(serializer.serialize(task)) == task
            ), "The task does not survive a serialization round trip."

        if self.validate_wandb:
            logging.info("Validating that Weights and Biases is enabled for task")
            assert "WANDB_API_KEY" in executor.env_vars, "WANDB_API_KEY is missing from the executor's env vars."
            assert task.log.wandb, "The task's logger has no wandb configuration."

        if self.validate_nodes_and_devices:
            logging.info("Validating that nodes and devices match for task and executor")
            if isinstance(executor, run.SlurmExecutor):
                assert (
                    task.trainer.num_nodes == executor.nodes
                ), "The trainer's num_nodes does not match the executor's nodes."
                assert (
                    task.trainer.devices == executor.nproc_per_node()
                ), "The trainer's devices does not match the executor's processes per node."
@dataclass(kw_only=True)
class PerfEnvPlugin(run.Plugin):
    """
    A plugin for setting up performance optimized environments.

    Attributes:
        enable_layernorm_sm_margin (bool): Set SM margin for TransformerEngine's Layernorm, so
            in order to not block DP level communication overlap.
        layernorm_sm_margin (int): The SM margin for TransformerEngine Layernorm.
        enable_vboost (bool): Whether to steer more power towards tensor cores via
            `sudo nvidia-smi boost-slider --vboost 1`. May not work on all systems.
        nccl_pp_comm_chunksize (Optional[int]): When set and the pipeline parallel size is
            larger than 1, exported as ``NCCL_P2P_NET_CHUNKSIZE``. Must be an int larger than 1.
        gpu_sm100_or_newer (bool): When tensor or context parallelism is used, selects
            ``CUDA_DEVICE_MAX_CONNECTIONS=32`` instead of ``1``.
        user_buffer_registration (bool): When True, sets ``NCCL_NVLS_ENABLE=1`` and
            ``NCCL_CTA_POLICY=1`` and strips ``expandable_segments:True`` from
            ``PYTORCH_CUDA_ALLOC_CONF``.
    """

    enable_layernorm_sm_margin: bool = True
    layernorm_sm_margin: int = 16
    enable_vboost: bool = False
    nccl_pp_comm_chunksize: Optional[int] = None
    gpu_sm100_or_newer: bool = False
    user_buffer_registration: bool = False

    def get_vboost_srun_cmd(self, nodes, job_dir):
        """Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command.

        Args:
            nodes: Number of tasks passed to ``srun --ntasks``.
            job_dir: Directory that receives ``vboost.out`` and ``vboost.err``.

        Returns:
            str: The ``srun`` command line, prefixed with a comment header.
        """
        import shlex

        vboost_cmd = " ".join(
            [
                "\n# Command 0: enable vboost\n\n",
                "srun",
                f"--ntasks={nodes}",
                "--output",
                os.path.join(job_dir, "vboost.out"),
                "--error",
                os.path.join(job_dir, "vboost.err"),
                "bash -c ",
                shlex.quote("sudo nvidia-smi boost-slider --vboost 1"),
            ],
        )

        return vboost_cmd

    def setup(self, task: run.Partial | run.Script, executor: run.Executor):
        """Enable the performance environment settings.

        All settings are applied through ``executor.env_vars``, except vboost, which is
        appended to ``executor.setup_lines`` for Slurm executors.
        """
        if task.trainer.strategy.__fn_or_cls__ == MegatronStrategy:
            # Force program order kernel launch for TP, CP overlap
            tp_size = task.trainer.strategy.tensor_model_parallel_size
            cp_size = task.trainer.strategy.context_parallel_size
            if tp_size > 1 or cp_size > 1:
                executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "32" if self.gpu_sm100_or_newer else "1"

            # Set LayerNorm SM margin to support the overlap with LayerNorm kernel
            if self.enable_layernorm_sm_margin:
                executor.env_vars["NVTE_FWD_LAYERNORM_SM_MARGIN"] = str(self.layernorm_sm_margin)
                executor.env_vars["NVTE_BWD_LAYERNORM_SM_MARGIN"] = str(self.layernorm_sm_margin)

            # Set the chunk size of P2P communications. Using a large chunk size reduces the
            # buffering overhead from the communication kernel execution time
            pp_size = task.trainer.strategy.pipeline_model_parallel_size
            if pp_size > 1 and self.nccl_pp_comm_chunksize is not None:
                assert isinstance(self.nccl_pp_comm_chunksize, int) and self.nccl_pp_comm_chunksize > 1
                executor.env_vars["NCCL_P2P_NET_CHUNKSIZE"] = str(self.nccl_pp_comm_chunksize)

        # Enable high priority for NCCL communications
        executor.env_vars["TORCH_NCCL_HIGH_PRIORITY"] = "1"

        if self.user_buffer_registration:
            # Enable NCCL NVLS ALGO, which could increase GPU memory usage
            executor.env_vars["NCCL_NVLS_ENABLE"] = "1"
            # This option makes NCCL to prefer SM efficient ALGOS if available
            # With this option, NCCL will use NVLS if user buffer is registered
            executor.env_vars["NCCL_CTA_POLICY"] = "1"
            if "PYTORCH_CUDA_ALLOC_CONF" in executor.env_vars:
                pytorch_cuda_alloc_conf = executor.env_vars["PYTORCH_CUDA_ALLOC_CONF"].split(',')
                if "expandable_segments:True" in pytorch_cuda_alloc_conf:
                    logging.warning(
                        "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True is not currently compatible with "
                        "user buffer registration. Removing expandable_segments:True from the list."
                    )
                    pytorch_cuda_alloc_conf.remove("expandable_segments:True")
                    executor.env_vars["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(pytorch_cuda_alloc_conf)

        # With CUDA graphs enabled, any custom allocator configuration is dropped entirely.
        if task.model.config.enable_cuda_graph and "PYTORCH_CUDA_ALLOC_CONF" in executor.env_vars:
            del executor.env_vars["PYTORCH_CUDA_ALLOC_CONF"]

        # Improve perf by steering power to tensor cores, may not work on all systems
        if self.enable_vboost and isinstance(executor, run.SlurmExecutor):
            vboost_cmd = self.get_vboost_srun_cmd(executor.nodes, executor.tunnel.job_dir)
            # Append to existing setup lines when there are any, otherwise start with the command.
            executor.setup_lines = executor.setup_lines + vboost_cmd if executor.setup_lines else vboost_cmd
def _triton_needs_cache_setup(version: str) -> bool:
    """Return True when ``version`` denotes a Triton release at or below 3.1.0.

    Only the leading dotted numeric part of the version is considered, and it is
    compared numerically. Comparing the raw strings would, for example, rank
    "10.0.0" below "3.1.0". Versions without a leading number yield False.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return False
    release = tuple(int(part) for part in match.group().split("."))
    return release <= (3, 1, 0)


@dataclass(kw_only=True)
class TritonCacheSetup(run.Plugin):
    """
    A plugin for setting up Triton cache environment variables.
    This should not be necessary for Triton 3.2.0 and above.

    The ``setup`` override is only defined when Triton is available and its version is
    3.1.0 or older; otherwise the ``setup`` inherited from ``run.Plugin`` is used.
    """

    from nemo.core.utils.optional_libs import TRITON_AVAILABLE

    if TRITON_AVAILABLE:
        from triton import __version__ as triton_version

        if _triton_needs_cache_setup(triton_version):

            def setup(self, task: run.Partial | run.Script, executor: run.Executor):
                """Set up the Triton cache environment variables."""
                # Join with a path separator so the cache lives inside the job directory
                # rather than in a sibling path formed by plain string concatenation.
                executor.env_vars["TRITON_CACHE_DIR"] = os.path.join(executor.job_dir, "triton_cache")
                executor.env_vars["TRITON_CACHE_MANAGER"] = (
                    "megatron.core.ssm.triton_cache_manager:ParallelFileCacheManager"
                )