winglian committed
Commit decb66e
1 Parent(s): 4d09b42

lora+ support (#1352)

* lora+ support

* optimizer should default to None

* include MIT license

src/axolotl/core/trainer_builder.py CHANGED
@@ -27,8 +27,10 @@ from transformers import (
     TrainingArguments,
 )
 from transformers.trainer_utils import seed_worker
+from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

+from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -54,6 +56,9 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_warmup_decay_constant,
 )

+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
 except ImportError:
@@ -179,6 +184,13 @@ class AxolotlTrainingArguments(TrainingArguments):
             "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
         },
     )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
+    )
+    loraplus_lr_embedding: Optional[float] = field(
+        default=1e-6,
+        metadata={"help": "loraplus learning rate for lora embedding layers."},
+    )


 class AxolotlTrainer(Trainer):
@@ -203,6 +215,33 @@ class AxolotlTrainer(Trainer):
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator

+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -915,6 +954,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
        )
+        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_arguments_kwargs[
+            "loraplus_lr_embedding"
+        ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler
             if self.cfg.lr_scheduler
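
For reference, a minimal sketch (not part of this commit) of what the `Trainer.get_optimizer_cls_and_kwargs` call inside the new `create_optimizer` override resolves to; the `output_dir`, learning rate, and `optim` values below are illustrative assumptions, not defaults introduced by this change.

from transformers import Trainer, TrainingArguments

# Illustrative arguments; only `optim` and `learning_rate` matter for this sketch.
args = TrainingArguments(
    output_dir="./outputs", learning_rate=2e-4, optim="adamw_torch"
)

# Static Trainer helper: resolves the optimizer class and its init kwargs from the args.
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)

print(optimizer_cls)           # e.g. <class 'torch.optim.adamw.AdamW'>
print(optimizer_kwargs["lr"])  # 0.0002; this becomes lr_A, and lr_B = lr_A * loraplus_lr_ratio

These two values are exactly what the override hands to `create_loraplus_optimizer`, defined in the new module below.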
src/axolotl/loraplus.py ADDED
@@ -0,0 +1,133 @@
+"""Module for LoRA+"""
+
+# MIT License
+#
+# Copyright (c) 2024 nikhil-ghosh-berkeley
+# https://github.com/nikhil-ghosh-berkeley/loraplus
+
+import logging
+from functools import reduce
+
+from peft.tuners import lora
+from torch import nn
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.trainer_pt_utils import get_parameter_names
+
+LOG = logging.getLogger("axolotl.loraplus")
+
+
+def get_module(name, opt_model):
+    """
+    Retrieve a module from a model using its parameter name.
+    Args:
+        name (str): Full name of the parameter, typically including module path.
+        opt_model (torch.nn.Module): The model from which to retrieve the module.
+
+    Returns:
+        Module corresponding to the given name.
+    """
+    parent_idx = 2 if "lora" in name else 1
+    module_names = name.split(sep=".")[:-parent_idx]
+    module = reduce(getattr, module_names, opt_model)
+    return module
+
+
+def create_loraplus_optimizer(
+    opt_model,
+    optimizer_cls,
+    optimizer_kwargs,
+    loraplus_lr_ratio,
+    loraplus_lr_embedding=None,
+):
+    """
+    Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
+
+    Args:
+        opt_model (torch.nn.Module): The model for which the optimizer is being created.
+        optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
+        optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
+        loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
+        loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
+
+    Returns:
+        An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
+    """
+
+    assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
+
+    if loraplus_lr_embedding is None:
+        loraplus_lr_embedding = 1e-6
+
+    decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    param_groups = {
+        "groupA": {},
+        "groupB": {},
+        "groupB_no_decay": {},
+        "embedding": {},
+    }
+
+    for name, param in opt_model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        module = get_module(name, opt_model)
+        if isinstance(module, lora.Embedding):
+            param_groups["embedding"][name] = param
+        elif "lora_B" in name or param.ndim == 1:
+            if name in decay_parameters:
+                param_groups["groupB"][name] = param
+            else:
+                param_groups["groupB_no_decay"][name] = param
+        else:
+            param_groups["groupA"][name] = param
+
+    assigned_param_groups = ""
+    for group, group_params in param_groups.items():
+        assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
+    LOG.info(assigned_param_groups)
+
+    lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
+    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
+
+    optimizer_grouped_parameters = [
+        {
+            "params": list(param_groups["groupA"].values()),
+            "weight_decay": weight_decay,
+            "lr": lr,
+        },
+        {
+            "params": list(param_groups["embedding"].values()),
+            "weight_decay": weight_decay,
+            "lr": loraplus_lr_embedding,
+        },
+        {
+            "params": list(param_groups["groupB"].values()),
+            "weight_decay": weight_decay,
+            "lr": lr * loraplus_lr_ratio,
+        },
+        {
+            "params": list(param_groups["groupB_no_decay"].values()),
+            "weight_decay": 0.0,
+            "lr": lr * loraplus_lr_ratio,
+        },
+    ]
+
+    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+    if optimizer_cls.__name__ == "Adam8bit":
+        import bitsandbytes
+
+        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+        skipped = 0
+        for module in opt_model.modules():
+            if isinstance(module, nn.Embedding):
+                skipped += sum(
+                    {p.data_ptr(): p.numel() for p in module.parameters()}.values()
+                )
+                LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                manager.register_module_override(module, "weight", {"optim_bits": 32})
+                LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
+        LOG.info(f"skipped: {skipped/2**20}M params")
+
+    return optimizer
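
The new module can also be exercised outside the trainer. Below is a minimal, hedged sketch (not part of the commit): a toy two-layer model wrapped with PEFT LoRA and handed to `create_loraplus_optimizer` with plain AdamW. The model, rank, and learning rates are illustrative assumptions only.

import torch
from peft import LoraConfig, get_peft_model
from torch import nn

from axolotl.loraplus import create_loraplus_optimizer


class TinyModel(nn.Module):
    """Toy model; exists only so PEFT can inject lora_A / lora_B parameters."""

    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(16, 32)
        self.proj_out = nn.Linear(32, 16)

    def forward(self, hidden):
        return self.proj_out(torch.relu(self.proj_in(hidden)))


lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["proj_in", "proj_out"])
peft_model = get_peft_model(TinyModel(), lora_config)

# lora_A weights land in groupA at lr; lora_B weights land in groupB at lr * ratio.
optimizer = create_loraplus_optimizer(
    peft_model,
    torch.optim.AdamW,
    {"lr": 2e-4, "weight_decay": 0.01},
    loraplus_lr_ratio=2**4,  # the recommendation surfaced in the config help text
    loraplus_lr_embedding=1e-6,
)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])

In normal use this is not called directly; `AxolotlTrainer.create_optimizer` above handles it whenever `loraplus_lr_ratio` is set.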
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -183,6 +183,17 @@ class LoraConfig(BaseModel):
     gptq: Optional[bool] = None
     bnb_config_kwargs: Optional[Dict[str, Any]] = None

+    loraplus_lr_ratio: Optional[float] = Field(
+        default=None,
+        metadata={
+            "help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
+        },
+    )
+    loraplus_lr_embedding: Optional[float] = Field(
+        default=1e-6,
+        metadata={"help": "loraplus learning rate for lora embedding layers."},
+    )
+
     merge_lora: Optional[bool] = None

     @model_validator(mode="before")
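
Putting the pieces together, a short hedged sketch of the effective per-group learning rates these config fields produce once they reach `create_loraplus_optimizer`. The base learning rate is an illustrative assumption; only the 1e-6 embedding default comes from the diff.

# Illustrative base learning rate (comes from the axolotl config / TrainingArguments).
learning_rate = 2e-4

loraplus_lr_ratio = 2**4      # recommended value from the LoraConfig help text above
loraplus_lr_embedding = 1e-6  # default for lora embedding layers

lr_group_a = learning_rate                      # lora_A and other trainable parameters
lr_group_b = learning_rate * loraplus_lr_ratio  # lora_B and 1-D params (decay and no-decay groups)
lr_embedding = loraplus_lr_embedding            # parameters of lora.Embedding modules

print(lr_group_a, lr_group_b, lr_embedding)     # 0.0002 0.0032 1e-06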