# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

from contextlib import contextmanager
from typing import TYPE_CHECKING, Generator, Literal, Optional, TypeVar

import torch
from lightning.fabric.plugins.precision import MixedPrecision
from torch import nn
from torch.optim import Optimizer

from nemo.lightning.fabric.conversion import to_fabric
from nemo.lightning.pytorch.plugins.mixed_precision import (
    DtypeConfig,
    MegatronMixedPrecision,
    get_optim_config,
    update_config_with_dtype_overrides,
)
from nemo.utils import logging

if TYPE_CHECKING:
    from megatron.core.model_parallel_config import ModelParallelConfig

AnyT = TypeVar("AnyT")
ConfigT = TypeVar("ConfigT", bound="ModelParallelConfig")


class FabricMegatronMixedPrecision(MixedPrecision):
    """Fabric plugin for mixed precision training with Megatron models.

    Handles precision conversions and mixed precision training settings
    in the Fabric training framework.
    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed", "32"],
        params_dtype: torch.dtype = None,
        pipeline_dtype: torch.dtype = None,
        autocast_dtype: torch.dtype = None,
        autocast_enabled: bool = False,
        grad_reduce_in_fp32: bool = True,
        # fp8 related,
        fp8: str = None,
        fp8_recipe: Optional[str] = None,
        first_last_layers_bf16: bool = False,
        num_layers_at_start_in_bf16: int = 0,
        num_layers_at_end_in_bf16: int = 0,
        reuse_grad_buf_for_mxfp8_param_ag: bool = False,
        fp8_margin: int = 0,
        fp8_amax_history_len: int = 1,
        fp8_amax_compute_algo: str = "most_recent",
        fp8_wgrad: bool = True,
        fp8_dot_product_attention: bool = False,
        fp8_multi_head_attention: bool = False,
        fp8_params: bool = None,
        fp8_param_gather: bool = None,
        fp16_loss_scale: float = None,
        fp16_initial_loss_scale: float = 4294967296,
        fp16_min_loss_scale: float = 1.0,
        fp16_loss_scale_window: int = 1000,
        fp16_hysteresis: int = 2,
    ) -> None:
        if fp8_params is not None:
            logging.warning(
                "fp8_params is deprecated and will be removed in a future release, use fp8_param_gather instead"
            )
            if fp8_param_gather is not None and fp8_param_gather != fp8_params:
                raise ValueError(
                    "Getting conflicting values for fp8_params and fp8_param_gather. Please only set fp8_param_gather."
                )
            fp8_param_gather = fp8_params
        elif fp8_param_gather is None:
            fp8_param_gather = False

        if isinstance(precision, int):
            precision = str(precision)
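        # Default compute dtype used when pipeline_dtype/autocast_dtype are not given
        # explicitly: bfloat16 for the bf16 precisions, float32 otherwise.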
        dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
        self.dtype_config = DtypeConfig(
            fp32=precision in ['fp32', '32'],
            fp16=precision in ['fp16', 'fp16-mixed', '16', '16-mixed'],
            bf16=precision in ['bf16', 'bf16-mixed'],
            params_dtype=params_dtype or torch.float32,
            pipeline_dtype=pipeline_dtype or dtype,
            autocast_dtype=autocast_dtype or dtype,
            autocast_enabled=autocast_enabled,
            grad_reduce_in_fp32=grad_reduce_in_fp32,
            fp8=fp8,
            fp8_recipe=fp8_recipe,
            first_last_layers_bf16=first_last_layers_bf16,
            num_layers_at_start_in_bf16=num_layers_at_start_in_bf16,
            num_layers_at_end_in_bf16=num_layers_at_end_in_bf16,
            reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
            fp8_margin=fp8_margin,
            fp8_amax_history_len=fp8_amax_history_len,
            fp8_amax_compute_algo=fp8_amax_compute_algo,
            fp8_wgrad=fp8_wgrad,
            fp8_dot_product_attention=fp8_dot_product_attention,
            fp8_multi_head_attention=fp8_multi_head_attention,
            fp8_param=fp8_param_gather,
            fp8_param_gather=fp8_param_gather,
            # fp16 loss scale
            loss_scale=fp16_loss_scale,
            initial_loss_scale=fp16_initial_loss_scale,
            min_loss_scale=fp16_min_loss_scale,
            loss_scale_window=fp16_loss_scale_window,
            hysteresis=fp16_hysteresis,
        )
        if self.dtype_config.fp16:
            self.precision = "16-mixed"
        elif self.dtype_config.bf16:
            self.precision = "bf16-mixed"
        else:
            self.precision = "32-true"
        self.scaler = None

    def convert_input(self, data: AnyT) -> AnyT:
        """Convert model inputs (forward) to the floating point precision type of this plugin.

        Note: MegatronStrategy will take care of only doing this when:
            mpu.is_pipeline_first_stage()
        """
        return data

    def convert_output(self, data: AnyT) -> AnyT:
        """Convert outputs to the floating point precision type expected after model's forward.

        Note: MegatronStrategy will take care of only doing this when:
            mpu.is_pipeline_last_stage()
        """
        return data

    def convert_config(self, config: ConfigT) -> ConfigT:
        """Convert the config to the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        return update_config_with_dtype_overrides(self.dtype_config, config)

    def convert_module(self, module: nn.Module) -> nn.Module:
        """Convert the module parameters to the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        if not hasattr(module, "module"):
            return module

        from megatron.core.transformer.module import Float16Module
        from megatron.core.utils import get_model_config
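
        # Float16Module wraps the Megatron module so parameters and float inputs are cast
        # to fp16/bf16, with outputs cast back to fp32 on the last pipeline stage.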
        if self.dtype_config.fp16 or self.dtype_config.bf16:
            # Patch config options
            config = get_model_config(module.module)
            config.fp16 = self.dtype_config.fp16
            config.bf16 = self.dtype_config.bf16
            # Avoid rewrapping the module if it's already of type Float16Module
            if hasattr(module, "module"):
                if not isinstance(module.module, Float16Module):
                    module.module = Float16Module(config, module.module)
            elif not isinstance(module, Float16Module):
                module = Float16Module(config, module)

        return module

    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Convert the optimizer parameters to the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        for optim_config in get_optim_config(optimizer):
            assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch"
            assert optim_config.fp16 == self.dtype_config.fp16, "FP16 model/optim config mismatch"
        return optimizer

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """No explicit precision casting. Inputs are expected to be cast manually."""
        try:
            yield
        finally:
            pass


@to_fabric.register(MegatronMixedPrecision)
def _convert_megatron_mixed_precision(plugin: MegatronMixedPrecision) -> FabricMegatronMixedPrecision:
    return FabricMegatronMixedPrecision(
        precision=plugin.precision,
    )
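
# A minimal usage sketch (not part of the upstream module; assumes the surrounding
# Fabric/MegatronStrategy wiring exists elsewhere). The Fabric plugin can be built
# directly, or obtained from a PTL ``MegatronMixedPrecision`` plugin through the
# ``to_fabric`` single-dispatch registry:
#
#     fabric_precision = FabricMegatronMixedPrecision(precision="bf16-mixed")
#     # or, converting an existing PTL plugin:
#     fabric_precision = to_fabric(MegatronMixedPrecision(precision="bf16-mixed"))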