dg-kalle committed
Commit
ef24342
1 Parent(s): 5ea3aa3

fix: switch to using the HuggingFace Transformers NEFT implementation (#941)


* fix: switch to using the HuggingFace Transformers NEFT implementation

* linter

* add support for noisy_embedding_alpha with a warning about it being renamed

* restore pre/posttrain_hooks

* move validation of NEFT noise alpha into validate_config()

* linter
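For context, NEFT/NEFTune (https://arxiv.org/abs/2310.05914) perturbs the input embeddings with uniform noise during training, scaled by alpha over the square root of (sequence length × hidden size). A minimal sketch of that step, mirroring the monkeypatch this commit deletes rather than any Axolotl or transformers API:

```python
# Sketch of the NEFT noise step (mirrors the deleted monkeypatch below);
# embeddings has shape (batch, seq_len, hidden_dim). Illustrative only.
import torch

def add_neft_noise(embeddings: torch.Tensor, alpha: float) -> torch.Tensor:
    mag_norm = alpha / (embeddings.size(1) * embeddings.size(2)) ** 0.5
    return embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
```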

README.md CHANGED
@@ -774,7 +774,7 @@ max_grad_norm:
 # Augmentation techniques
 # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
 # currently only supported on Llama and Mistral
-noisy_embedding_alpha:
+neftune_noise_alpha:
 
 # Whether to bettertransformers
 flash_optimum:
src/axolotl/core/trainer_builder.py CHANGED
@@ -712,6 +712,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs
         )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+
+        if self.cfg.neftune_noise_alpha is not None:
+            training_arguments_kwargs[
+                "neftune_noise_alpha"
+            ] = self.cfg.neftune_noise_alpha
+
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
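With this change the noise option rides along with the rest of the training arguments: AxolotlTrainingArguments extends transformers' TrainingArguments, which exposes a neftune_noise_alpha field in recent releases (roughly 4.35+ — an assumption, this diff does not pin a version), and the upstream Trainer then adds and removes the embedding noise itself. A quick stand-alone check of that assumption, outside of Axolotl:

```python
# Not part of this commit: verify that transformers' TrainingArguments accepts
# the forwarded kwarg (assumes a release with built-in NEFTune support).
from transformers import TrainingArguments

args = TrainingArguments(output_dir="./out", neftune_noise_alpha=5.0)
print(args.neftune_noise_alpha)  # 5.0
```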
src/axolotl/monkeypatch/neft_embeddings.py DELETED
@@ -1,65 +0,0 @@
-"""
-patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
-"""
-import torch
-from peft import PeftModel
-from transformers import PreTrainedModel
-
-
-def patch_neft(alpha, model):
-    embeddings = None
-    if isinstance(model, PreTrainedModel):
-        embeddings = model.get_input_embeddings()
-    if isinstance(model, PeftModel):
-        embeddings = model.base_model.get_input_embeddings()
-    if not embeddings:
-        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
-    embeddings.noisy_embedding_alpha = alpha
-    old_forward = embeddings.forward
-
-    # This hack seems to be needed to properly use a custom forward pass
-    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
-    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
-        embeddings, embeddings.__class__
-    )
-    setattr(embeddings, "forward", bound_method)
-
-    embeddings._old_forward = old_forward  # pylint: disable=protected-access
-    return model
-
-
-def unpatch_neft(model):
-    embeddings = None
-    if isinstance(model, PreTrainedModel):
-        embeddings = model.get_input_embeddings()
-    if isinstance(model, PeftModel):
-        embeddings = model.base_model.get_input_embeddings()
-    if not embeddings:
-        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
-    if hasattr(embeddings, "_old_forward"):
-        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
-        del embeddings._old_forward  # pylint: disable=protected-access
-        del embeddings.noisy_embedding_alpha
-
-
-def neft_forward(self, inputs: torch.Tensor):
-    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access
-
-    if self.training:
-        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
-        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
-        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-            -mag_norm, mag_norm
-        )
-
-    return embeddings
-
-
-def pretrain_hook(cfg, trainer):
-    if cfg.noisy_embedding_alpha:
-        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
-
-
-def post_train_hook(cfg, trainer):
-    if cfg.noisy_embedding_alpha:
-        unpatch_neft(trainer.model)
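The magnitude in the deleted neft_forward — alpha divided by the square root of (sequence length × hidden size) — is the same scaling the NEFT paper describes. As a rough worked example of the numbers involved (illustrative shapes, not taken from this diff):

```python
# Illustrative scale of the injected noise for alpha = 5 (the paper's default)
# on a 512-token batch with 4096-dim embeddings.
import math

alpha, seq_len, hidden_dim = 5.0, 512, 4096
mag_norm = alpha / math.sqrt(seq_len * hidden_dim)
print(f"{mag_norm:.4f}")  # ~0.0035: each element gets uniform noise in [-0.0035, 0.0035]
```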
src/axolotl/train.py CHANGED
@@ -16,7 +16,6 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
-from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
@@ -180,21 +179,19 @@ def train(
     return model, tokenizer
 
 
-def pretrain_hooks(cfg, trainer):
+def pretrain_hooks(_cfg, _trainer):
     """
     Run hooks right before kicking off the training
     :param cfg:
     :param trainer:
     :return:
     """
-    neft_embeddings.pretrain_hook(cfg, trainer)
 
 
-def post_train_hooks(cfg, trainer):
+def post_train_hooks(_cfg, _trainer):
     """
     Run hooks right after training completes
     :param cfg:
     :param trainer:
     :return:
     """
-    neft_embeddings.post_train_hook(cfg, trainer)
src/axolotl/utils/config.py CHANGED
@@ -434,6 +434,20 @@ def validate_config(cfg):
             "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
         )
 
+    if cfg.noisy_embedding_alpha is not None:
+        # Deprecated, use neftune_noise_alpha
+        LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
+        if cfg.neftune_noise_alpha is None:
+            cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
+        else:
+            # User is providing both; bail and have them sort out their settings
+            raise ValueError(
+                "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
+            )
+
+    if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
+        raise ValueError("neftune_noise_alpha must be > 0.0")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
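The deprecation shim above maps the old key onto the new one and refuses ambiguous configs. A hypothetical smoke test of that behavior, assuming validate_config() tolerates an otherwise-minimal DictDefault config:

```python
# Hypothetical check of the deprecation handling added above; both imports
# exist in this repo, but the minimal configs are made up for illustration.
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"noisy_embedding_alpha": 5})
validate_config(cfg)                    # logs the rename warning
assert cfg.neftune_noise_alpha == 5     # value carried over to the new key

both = DictDefault({"noisy_embedding_alpha": 5, "neftune_noise_alpha": 5})
try:
    validate_config(both)               # both keys set -> ValueError
except ValueError as err:
    print(err)
```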