# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks.callback import Callback
from megatron.core.transformer import TransformerConfig


class DeepEPCallback(Callback):
    """
    A PyTorch Lightning callback that enables DeepEP when the hardware supports it.

    Per the official documentation (https://github.com/deepseek-ai/DeepEP), DeepEP is
    supported on Ampere (SM80) and Hopper (SM90) GPUs.

    Adding this callback is equivalent to setting the following flags in the recipe function:

        recipe.model.config.moe_token_dispatcher_type = "flex"
        recipe.model.config.moe_enable_deepep = True
        recipe.model.config.moe_shared_expert_overlap = False

    Because the recipe function may be executed on a different machine than the training job,
    this callback is needed so that the configs are set at runtime, based on the GPU that is
    actually present.
    """

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
        """Enable DeepEP if the GPU is Ampere (SM80) or Hopper (SM90)."""
        # Compute capability major version 8 is Ampere, 9 is Hopper; skip otherwise.
        if torch.cuda.get_device_properties(0).major not in [8, 9]:
            return

        # Apply the settings to the model's live Megatron config.
        if hasattr(trainer.model, "config") and isinstance(trainer.model.config, TransformerConfig):
            self._apply_deepep_cfgs(trainer.model.config)
        # Also apply them to the config captured by NeMo's IO system, if present.
        if hasattr(trainer.model, "__io__"):
            self._apply_deepep_cfgs(trainer.model.__io__.config)

    def _apply_deepep_cfgs(self, dest_cfg):
        """Apply the DeepEP-related MoE settings to ``dest_cfg``."""
        dest_cfg.moe_token_dispatcher_type = "flex"
        dest_cfg.moe_enable_deepep = True
        dest_cfg.moe_shared_expert_overlap = False
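

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): the callback is
    # attached to a Lightning Trainer like any other callback, and it flips the MoE flags
    # on the model's TransformerConfig during ``setup`` when the detected GPU is Ampere
    # (SM80) or Hopper (SM90). The trainer arguments below are placeholders, not a
    # recommended training configuration.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,
        callbacks=[DeepEPCallback()],
    )
    print([type(cb).__name__ for cb in trainer.callbacks])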