# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by HuggingFace's Transformers and PEFT libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for the block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        # Mark floating-point inputs of trainable modules as requiring grad, so that
        # checkpointing keeps working when a block-wise optimizer (e.g. BAdam) leaves
        # only part of the model trainable.
        module: "torch.nn.Module" = func.__self__
        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)

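
# A minimal usage sketch: this patched method is meant to be bound onto a loaded model
# with MethodType and then invoked, as `prepare_model_for_training` does below
# (`model` is assumed to be any Transformers model that supports gradient checkpointing):
#
#     model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
#     model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
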

def _fp32_forward_post_hook(
    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
    return output.to(torch.float32)

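
# A minimal sketch of the hook's effect on a standalone layer (illustrative only, not
# part of this module; bfloat16 is chosen because it is widely supported on CPU):
#
#     layer = torch.nn.Linear(4, 2).to(torch.bfloat16)
#     layer.register_forward_hook(_fp32_forward_post_hook)
#     output = layer(torch.randn(1, 4, dtype=torch.bfloat16))
#     assert output.dtype == torch.float32  # the hook replaces the output with an fp32 copy
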

def prepare_model_for_training(
    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
) -> None:
    r"""
    Prepares the model before training. Includes:
        (1) casting the layernorm weights to fp32
        (2) making the output embedding layer require grads
        (3) upcasting the lm_head outputs to fp32
    """
    if model_args.upcast_layernorm:
        logger.info("Upcasting layernorm weights in float32.")
        for name, param in model.named_parameters():
            if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                param.data = param.data.to(torch.float32)

    if not model_args.disable_gradient_checkpointing:
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning("Current model does not support gradient checkpointing.")
        else:
            # use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
            # According to: https://github.com/huggingface/transformers/issues/28339
            model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
            setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")

    if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
        logger.info("Upcasting lm_head outputs in float32.")
        output_layer = getattr(model, output_layer_name)
        if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
            output_layer.register_forward_hook(_fp32_forward_post_hook)
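
# A minimal usage sketch, assuming a Transformers causal LM and LlamaFactory's
# ModelArguments (the model name and argument values below are illustrative):
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")
#     prepare_model_for_training(model, model_args)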