Enxin's picture
Upload folder using huggingface_hub
96fe658 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
import torch
from torch import nn
from swift.utils.logger import get_logger
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput
# Module-level logger obtained from swift's logging helper.
logger = get_logger()
@dataclass
class NEFTuneConfig(SwiftConfig):
    """
    Configuration for the NEFTune tuner.

    NEFTune adds a small amount of noise to the outputs of embedding layers.
    See https://arxiv.org/abs/2310.05914

    Args:
        noise_alpha(`float`): The noise alpha value used for NEFTune, defaults to 5.0
    """

    noise_alpha: float = field(
        default=5.0,
        metadata={'help': 'The noise alpha value used for the NEFTune'},
    )

    def __post_init__(self):
        # Imported inside the method rather than at module load time
        # (presumably to avoid a circular import with `.mapping` -- confirm).
        from .mapping import SwiftTuners

        # Tag this config with the NEFTune tuner type.
        self.swift_type = SwiftTuners.NEFTUNE
class NEFTune(SwiftAdapter):
    """Tuner that perturbs embedding outputs with uniform noise (NEFTune).

    The adapter only registers a forward hook and a boolean `nef_activated`
    flag on every `torch.nn.Embedding` of the model.
    See https://arxiv.org/abs/2310.05914
    """

    @staticmethod
    def prepare_model(model: nn.Module, config: NEFTuneConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `NEFTuneConfig`.

        Registers a noise-injecting forward hook on every `torch.nn.Embedding`
        found in `model` and marks each of them with `nef_activated = True`.

        Args:
            model: The model whose embedding layers are hooked.
            config: Supplies `noise_alpha`, read on every hooked forward pass.
            adapter_name: Not used by this tuner.

        Returns:
            A `SwiftOutput` holding `config`, a state-dict callback that
            returns the state dict unchanged, and a no-op trainable callback.

        Raises:
            ValueError: If an embedding already has a `nef_activated`
                attribute, i.e. NEFTune was already applied to it.
        """

        # One hook object is shared by all embeddings: it depends only on
        # `config`, not on the module it is attached to.
        def neftune_hook(module, args, output):
            # Noise is injected only in training mode and while the tuner is
            # switched on; a missing flag is treated as "off".
            if not (module.training and getattr(module, 'nef_activated', False)):
                return output
            # Product of the last two output dimensions (presumably
            # sequence length * embedding dim for a batched input).
            dims = output.size(-1) * output.size(-2)
            if dims == 0:
                # Nothing to perturb, and the scale below would divide by zero.
                return output
            # Plain-float bound: no per-call tensor allocation is needed.
            mag_norm = config.noise_alpha / dims**0.5
            return output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)

        for sub_module in model.modules():
            if not isinstance(sub_module, torch.nn.Embedding):
                continue
            if hasattr(sub_module, 'nef_activated'):
                raise ValueError('NEFTune does not support a second tuner.')
            sub_module.register_forward_hook(neftune_hook)
            sub_module.nef_activated = True

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # The state dict is passed through untouched.
            return state_dict

        def mark_trainable_callback(model):
            # Nothing is marked trainable by this tuner.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Switch noise injection on or off for every embedding in `module`.

        Args:
            module: The module tree to walk.
            adapter_name: Not used by this tuner.
            activate: New value of the `nef_activated` flag.
            offload: Not used by this tuner.
        """
        for sub_module in module.modules():
            if isinstance(sub_module, torch.nn.Embedding):
                sub_module.nef_activated = activate

    @staticmethod
    def freeze_model():
        """Return False (presumably: the base model is not frozen -- confirm against `SwiftAdapter`)."""
        return False

    @staticmethod
    def has_additional_modules():
        """Return False: `prepare_model` adds hooks and a flag, not sub-modules."""
        return False