|
|
|
from dataclasses import dataclass, field |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from swift.utils.logger import get_logger |
|
from .utils import SwiftAdapter, SwiftConfig, SwiftOutput |
|
|
|
logger = get_logger() |
|
|
|
|
|
@dataclass
class NEFTuneConfig(SwiftConfig):
    """
    Configuration for the NEFTune tuner.

    NEFTune adds a small amount of noise to the embedding outputs.
    See https://arxiv.org/abs/2310.05914

    Args:
        noise_alpha(`float`): The alpha value that scales the NEFTune noise, default 5.0
    """

    noise_alpha: float = field(
        default=5.0,
        metadata={'help': 'The noise alpha value used for the NEFTune'},
    )

    def __post_init__(self):
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import with `.mapping` - confirm).
        from .mapping import SwiftTuners as _SwiftTuners

        self.swift_type = _SwiftTuners.NEFTUNE
|
|
|
|
|
class NEFTune(SwiftAdapter):
    """Tuner that adds NEFTune noise to the outputs of `torch.nn.Embedding` modules."""

    @staticmethod
    def prepare_model(model: nn.Module, config: NEFTuneConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `NEFTuneConfig`.

        Registers a forward hook on every `torch.nn.Embedding` sub-module of `model` and marks
        each of them with `nef_activated = True`.

        Args:
            model: The model whose embedding modules get the noise hook.
            config: The NEFTune config; `config.noise_alpha` scales the noise.
            adapter_name: The adapter name (not used by this tuner).

        Returns:
            A `SwiftOutput` holding the config, a state-dict callback that returns the state dict
            unchanged, and a mark-trainable callback that does nothing.

        Raises:
            ValueError: If an embedding module already carries a `nef_activated` attribute.
        """

        # Defined once and shared by all embeddings; it only depends on `config`.
        def neftune_hook(module, args, output):
            # A missing flag is treated as "not activated" instead of raising AttributeError.
            if module.training and getattr(module, 'nef_activated', False):
                dims = output.size(-1) * output.size(-2)
                # Skip degenerate (empty) outputs: there is nothing to perturb and the
                # magnitude below would divide by zero.
                if dims > 0:
                    # Noise is drawn uniformly from [-alpha / sqrt(dims), alpha / sqrt(dims)].
                    # Plain float math avoids building a helper tensor on every forward pass.
                    mag_norm = config.noise_alpha / dims**0.5
                    output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
            return output

        for sub_module in model.modules():
            if isinstance(sub_module, torch.nn.Embedding):
                if hasattr(sub_module, 'nef_activated'):
                    raise ValueError('NEFTune does not support a second tuner.')

                sub_module.register_forward_hook(neftune_hook)
                sub_module.nef_activated = True

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # The state dict is passed through untouched.
            return state_dict

        def mark_trainable_callback(model):
            # Nothing is marked as trainable by this tuner.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Switch the NEFTune noise on or off.

        Sets `nef_activated` to `activate` on every `torch.nn.Embedding` sub-module of `module`.

        Args:
            module: The module whose embedding sub-modules are updated.
            adapter_name: The adapter name (not used here).
            activate: The new value of the `nef_activated` flag.
            offload: Not used here.
        """
        for sub_module in module.modules():
            if isinstance(sub_module, torch.nn.Embedding):
                sub_module.nef_activated = activate

    @staticmethod
    def freeze_model():
        """Return False (presumably: the base model is not frozen for this tuner - confirm)."""
        return False

    @staticmethod
    def has_additional_modules():
        """Return False (presumably: this tuner adds no extra modules to the model - confirm)."""
        return False
|
|