lora+ support (#1352)

* lora+ support
* optimizer should default to None
* include mit license
src/axolotl/core/trainer_builder.py (CHANGED)
@@ -27,8 +27,10 @@ from transformers import (
     TrainingArguments,
 )
 from transformers.trainer_utils import seed_worker
+from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
 
+from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
@@ -54,6 +56,9 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_warmup_decay_constant,
 )
 
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
 except ImportError:
@@ -179,6 +184,13 @@ class AxolotlTrainingArguments(TrainingArguments):
             "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
         },
     )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
+    )
+    loraplus_lr_embedding: Optional[float] = field(
+        default=1e-6,
+        metadata={"help": "loraplus learning rate for lora embedding layers."},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -203,6 +215,33 @@ class AxolotlTrainer(Trainer):
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
 
+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                optimizer_kwargs,
+                loraplus_lr_ratio,
+                loraplus_lr_embedding,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -915,6 +954,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["optim"] = (
            self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
         )
+        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_arguments_kwargs[
+            "loraplus_lr_embedding"
+        ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["lr_scheduler_type"] = (
             self.cfg.lr_scheduler
             if self.cfg.lr_scheduler
src/axolotl/loraplus.py (ADDED)
@@ -0,0 +1,133 @@
+"""Module for LoRA+"""
+
+# MIT License
+#
+# Copyright (c) 2024 nikhil-ghosh-berkeley
+# https://github.com/nikhil-ghosh-berkeley/loraplus
+
+import logging
+from functools import reduce
+
+from peft.tuners import lora
+from torch import nn
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.trainer_pt_utils import get_parameter_names
+
+LOG = logging.getLogger("axolotl.loraplus")
+
+
+def get_module(name, opt_model):
+    """
+    Retrieve a module from a model using its parameter name.
+    Args:
+        name (str): Full name of the parameter, typically including module path.
+        opt_model (torch.nn.Module): The model from which to retrieve the module.
+
+    Returns:
+        Module corresponding to the given name.
+    """
+    parent_idx = 2 if "lora" in name else 1
+    module_names = name.split(sep=".")[:-parent_idx]
+    module = reduce(getattr, module_names, opt_model)
+    return module
+
+
+def create_loraplus_optimizer(
+    opt_model,
+    optimizer_cls,
+    optimizer_kwargs,
+    loraplus_lr_ratio,
+    loraplus_lr_embedding=None,
+):
+    """
+    Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
+
+    Args:
+        opt_model (torch.nn.Module): The model for which the optimizer is being created.
+        optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
+        optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
+        loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
+        loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
+
+    Returns:
+        An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
+    """
+
+    assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
+
+    if loraplus_lr_embedding is None:
+        loraplus_lr_embedding = 1e-6
+
+    decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    param_groups = {
+        "groupA": {},
+        "groupB": {},
+        "groupB_no_decay": {},
+        "embedding": {},
+    }
+
+    for name, param in opt_model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        module = get_module(name, opt_model)
+        if isinstance(module, lora.Embedding):
+            param_groups["embedding"][name] = param
+        elif "lora_B" in name or param.ndim == 1:
+            if name in decay_parameters:
+                param_groups["groupB"][name] = param
+            else:
+                param_groups["groupB_no_decay"][name] = param
+        else:
+            param_groups["groupA"][name] = param
+
+    assigned_param_groups = ""
+    for group, group_params in param_groups.items():
+        assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
+    LOG.info(assigned_param_groups)
+
+    lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
+    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
+
+    optimizer_grouped_parameters = [
+        {
+            "params": list(param_groups["groupA"].values()),
+            "weight_decay": weight_decay,
+            "lr": lr,
+        },
+        {
+            "params": list(param_groups["embedding"].values()),
+            "weight_decay": weight_decay,
+            "lr": loraplus_lr_embedding,
+        },
+        {
+            "params": list(param_groups["groupB"].values()),
+            "weight_decay": weight_decay,
+            "lr": lr * loraplus_lr_ratio,
+        },
+        {
+            "params": list(param_groups["groupB_no_decay"].values()),
+            "weight_decay": 0.0,
+            "lr": lr * loraplus_lr_ratio,
+        },
+    ]
+
+    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+    if optimizer_cls.__name__ == "Adam8bit":
+        import bitsandbytes
+
+        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+        skipped = 0
+        for module in opt_model.modules():
+            if isinstance(module, nn.Embedding):
+                skipped += sum(
+                    {p.data_ptr(): p.numel() for p in module.parameters()}.values()
+                )
+                LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                manager.register_module_override(module, "weight", {"optim_bits": 32})
+                LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
+        LOG.info(f"skipped: {skipped/2**20}M params")
+
+    return optimizer
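
As a hedged usage sketch (not part of this PR), create_loraplus_optimizer can also be exercised on its own; the model, optimizer class, and hyperparameters below are illustrative assumptions, not values taken from axolotl:

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from axolotl.loraplus import create_loraplus_optimizer

# Wrap a small causal LM with LoRA adapters so only lora_A/lora_B train.
base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

optimizer = create_loraplus_optimizer(
    peft_model,
    torch.optim.AdamW,
    {"lr": 2e-4, "weight_decay": 0.0},
    loraplus_lr_ratio=16,        # lora_B groups train at 16 * lr
    loraplus_lr_embedding=1e-6,  # lora embedding layers get their own lr
)
# Four parameter groups are built: groupA at lr, embedding at 1e-6,
# and groupB / groupB_no_decay at lr * loraplus_lr_ratio.
print([group["lr"] for group in optimizer.param_groups])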
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (CHANGED)
@@ -183,6 +183,17 @@ class LoraConfig(BaseModel):
     gptq: Optional[bool] = None
     bnb_config_kwargs: Optional[Dict[str, Any]] = None
 
+    loraplus_lr_ratio: Optional[float] = Field(
+        default=None,
+        metadata={
+            "help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
+        },
+    )
+    loraplus_lr_embedding: Optional[float] = Field(
+        default=1e-6,
+        metadata={"help": "loraplus learning rate for lora embedding layers."},
+    )
+
     merge_lora: Optional[bool] = None
 
     @model_validator(mode="before")
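
As a quick numeric check (a sketch, not code from the PR) of what the recommended 2^4 ratio means: with a base learning rate of 2e-4, groupA (lora_A and other trainable weights) keeps 2e-4, the lora_B groups train at 3.2e-3, and lora embedding layers fall back to loraplus_lr_embedding.

base_lr = 2e-4
loraplus_lr_ratio = 2**4
loraplus_lr_embedding = 1e-6

lr_a = base_lr                      # groupA: lora_A matrices and other trainable weights
lr_b = base_lr * loraplus_lr_ratio  # groupB / groupB_no_decay: lora_B and 1-d params
print(lr_a, lr_b, loraplus_lr_embedding)  # 0.0002 0.0032 1e-06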