|
""" |
|
Patches applied through the trainer hooks to enable NEFTune (noisy embeddings) per https://arxiv.org/abs/2310.05914
|
""" |
|
import torch |
|
from peft import PeftModel |
|
from transformers import PreTrainedModel |
|
|
|
|
|
def patch_neft(alpha, model):
    """Patch the model's input-embedding module so its forward adds NEFT noise.

    The embedding module is looked up via ``get_input_embeddings()`` (on the
    model itself for a ``PreTrainedModel``, on ``model.base_model`` for a
    ``PeftModel``). Its ``forward`` is replaced by ``neft_forward`` bound to that
    module; the previous ``forward`` is kept on the module as ``_old_forward``
    and ``alpha`` is stored as ``noisy_embedding_alpha``.

    Calling this on an already-patched model only refreshes the stored alpha.
    Re-wrapping would save the patched forward as ``_old_forward`` and make
    ``neft_forward`` call itself recursively.

    Args:
        alpha: noise scale read by ``neft_forward`` as ``noisy_embedding_alpha``.
        model: a ``PreTrainedModel`` or ``PeftModel`` instance.

    Returns:
        The same ``model`` object, patched in place.

    Raises:
        ValueError: if no input-embedding module could be obtained for ``model``.
    """
    embeddings = None
    if isinstance(model, PreTrainedModel):
        embeddings = model.get_input_embeddings()
    if isinstance(model, PeftModel):
        embeddings = model.base_model.get_input_embeddings()
    # Explicit None check: a module's truthiness can depend on __len__.
    if embeddings is None:
        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")

    embeddings.noisy_embedding_alpha = alpha

    # Already patched: keep the saved original forward, do not wrap again.
    if hasattr(embeddings, "_old_forward"):
        return model

    embeddings._old_forward = embeddings.forward
    # Bind the plain function to this instance so ``self`` is the embedding
    # module when the model calls ``embeddings.forward(...)``.
    bound_method = neft_forward.__get__(embeddings, embeddings.__class__)
    setattr(embeddings, "forward", bound_method)

    return model
|
|
|
|
|
def unpatch_neft(model):
    """Undo ``patch_neft``: restore the embedding module's original forward.

    The embedding module is located the same way as in ``patch_neft``. If it
    carries a saved ``_old_forward``, that callable is put back as ``forward``
    and the ``_old_forward`` / ``noisy_embedding_alpha`` attributes are removed.
    A model that was never patched is left untouched.

    Args:
        model: a ``PreTrainedModel`` or ``PeftModel`` instance.

    Raises:
        ValueError: if no input-embedding module could be obtained for ``model``.
    """
    embeddings = None
    if isinstance(model, PreTrainedModel):
        embeddings = model.get_input_embeddings()
    if isinstance(model, PeftModel):
        embeddings = model.base_model.get_input_embeddings()
    # Explicit None check: a module's truthiness can depend on __len__.
    if embeddings is None:
        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")

    if hasattr(embeddings, "_old_forward"):
        embeddings.forward = embeddings._old_forward
        del embeddings._old_forward
        # Guarded so a missing alpha cannot abort cleanup after the forward
        # has already been restored.
        if hasattr(embeddings, "noisy_embedding_alpha"):
            del embeddings.noisy_embedding_alpha
|
|
|
|
|
def neft_forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """Embedding forward that adds uniform NEFT noise while in training mode.

    Delegates to ``self._old_forward(inputs)`` and, only when ``self.training``
    is true, adds noise drawn uniformly from ``[-m, m]`` with
    ``m = self.noisy_embedding_alpha / sqrt(d1 * d2)``, where ``d1`` and ``d2``
    are the sizes of the last two dimensions of the embedding output.

    Args:
        inputs: tensor passed straight through to the saved original forward.

    Returns:
        The original embeddings in eval mode; the noised embeddings (a new
        tensor of the same shape) in training mode.
    """
    embeddings = self._old_forward(inputs)

    if not self.training:
        return embeddings

    # Last two dims: for the usual 3-D output this equals size(1) * size(2),
    # and it also handles a 2-D (unbatched) output instead of raising
    # IndexError. Assumes the output is laid out as (..., seq_len, hidden).
    dims = embeddings.size(-2) * embeddings.size(-1)
    if dims == 0:
        # Nothing to perturb; also avoids dividing by zero below.
        return embeddings

    # Plain Python float: no tensor is allocated per step for the scale.
    mag_norm = self.noisy_embedding_alpha / dims**0.5
    noise = torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
    return embeddings + noise
|
|
|
|
|
def pretrain_hook(cfg, trainer):
    """Apply the NEFT embedding patch to ``trainer.model`` before training.

    Does nothing when ``cfg.noisy_embedding_alpha`` is falsy. Otherwise the model
    returned by ``patch_neft`` is assigned back onto ``trainer.model``.
    """
    if not cfg.noisy_embedding_alpha:
        return
    trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
|
|
|
|
def post_train_hook(cfg, trainer):
    """Remove the NEFT embedding patch from ``trainer.model`` after training.

    Does nothing when ``cfg.noisy_embedding_alpha`` is falsy; otherwise it calls
    ``unpatch_neft`` on the trainer's model.
    """
    if not cfg.noisy_embedding_alpha:
        return
    unpatch_neft(trainer.model)
|
|