LISA (#1469)
Browse files* add lisa support
* fix default and fix attribute traversal for layers
* improve lisa callback logging
* fix LISA by ensuring params are not frozen during __init__
* example config for lisa
---------
Co-authored-by: Aman Karmani <aman@tmm1.net>
examples/llama-2/lisa.yml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_model: NousResearch/Llama-2-7b-hf
|
2 |
+
model_type: LlamaForCausalLM
|
3 |
+
tokenizer_type: LlamaTokenizer
|
4 |
+
|
5 |
+
load_in_8bit: false
|
6 |
+
load_in_4bit: false
|
7 |
+
strict: false
|
8 |
+
|
9 |
+
datasets:
|
10 |
+
- path: teknium/GPT4-LLM-Cleaned
|
11 |
+
type: alpaca
|
12 |
+
dataset_prepared_path: last_run_prepared
|
13 |
+
val_set_size: 0.05
|
14 |
+
output_dir: ./lisa-out
|
15 |
+
|
16 |
+
sequence_len: 4096
|
17 |
+
sample_packing: true
|
18 |
+
pad_to_sequence_len: true
|
19 |
+
|
20 |
+
adapter:
|
21 |
+
lora_model_dir:
|
22 |
+
lora_r:
|
23 |
+
lora_alpha:
|
24 |
+
lora_dropout:
|
25 |
+
lora_target_linear:
|
26 |
+
lora_fan_in_fan_out:
|
27 |
+
|
28 |
+
lisa_n_layers: 4
|
29 |
+
lisa_step_interval: 20
|
30 |
+
lisa_layers_attribute: model.layers
|
31 |
+
|
32 |
+
wandb_project:
|
33 |
+
wandb_entity:
|
34 |
+
wandb_watch:
|
35 |
+
wandb_name:
|
36 |
+
wandb_log_model:
|
37 |
+
|
38 |
+
gradient_accumulation_steps: 2
|
39 |
+
micro_batch_size: 1
|
40 |
+
num_epochs: 1
|
41 |
+
optimizer: adamw_bnb_8bit
|
42 |
+
lr_scheduler: cosine
|
43 |
+
learning_rate: 5e-5 # recommendation from lisa paper for 7b
|
44 |
+
|
45 |
+
train_on_inputs: false
|
46 |
+
group_by_length: false
|
47 |
+
bf16: auto
|
48 |
+
fp16:
|
49 |
+
tf32: false
|
50 |
+
|
51 |
+
gradient_checkpointing: true
|
52 |
+
early_stopping_patience:
|
53 |
+
resume_from_checkpoint:
|
54 |
+
local_rank:
|
55 |
+
logging_steps: 1
|
56 |
+
xformers_attention:
|
57 |
+
flash_attention: true
|
58 |
+
flash_attn_cross_entropy: false
|
59 |
+
flash_attn_rms_norm: true
|
60 |
+
flash_attn_fuse_qkv: false
|
61 |
+
flash_attn_fuse_mlp: true
|
62 |
+
|
63 |
+
warmup_steps: 100
|
64 |
+
evals_per_epoch: 4
|
65 |
+
eval_table_size:
|
66 |
+
saves_per_epoch: 1
|
67 |
+
debug:
|
68 |
+
deepspeed:
|
69 |
+
weight_decay: 0.1
|
70 |
+
fsdp:
|
71 |
+
fsdp_config:
|
72 |
+
special_tokens:
|
73 |
+
bos_token: "<s>"
|
74 |
+
eos_token: "</s>"
|
75 |
+
unk_token: "<unk>"
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -45,6 +45,7 @@ from axolotl.utils.callbacks import (
|
|
45 |
causal_lm_bench_eval_callback_factory,
|
46 |
log_prediction_callback_factory,
|
47 |
)
|
|
|
48 |
from axolotl.utils.collators import (
|
49 |
BatchSamplerDataCollatorForSeq2Seq,
|
50 |
DataCollatorForSeq2Seq,
|
@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
200 |
orpo_alpha: Optional[float] = field(
|
201 |
default=None,
|
202 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
|
205 |
class AxolotlTrainer(Trainer):
|
@@ -938,6 +951,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
938 |
)
|
939 |
callbacks.append(early_stop_cb)
|
940 |
|
|
|
|
|
941 |
return callbacks
|
942 |
|
943 |
def _get_trainer_cls(self):
|
@@ -1229,6 +1244,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
1229 |
"relora_prune_ratio"
|
1230 |
] = self.cfg.relora_prune_ratio
|
1231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1232 |
training_arguments_kwargs = self.hook_pre_create_training_args(
|
1233 |
training_arguments_kwargs
|
1234 |
)
|
|
|
45 |
causal_lm_bench_eval_callback_factory,
|
46 |
log_prediction_callback_factory,
|
47 |
)
|
48 |
+
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
49 |
from axolotl.utils.collators import (
|
50 |
BatchSamplerDataCollatorForSeq2Seq,
|
51 |
DataCollatorForSeq2Seq,
|
|
|
201 |
orpo_alpha: Optional[float] = field(
|
202 |
default=None,
|
203 |
)
|
204 |
+
lisa_n_layers: Optional[int] = field(
|
205 |
+
default=None,
|
206 |
+
metadata={"help": "the number of activate layers in LISA"},
|
207 |
+
)
|
208 |
+
lisa_step_interval: Optional[int] = field(
|
209 |
+
default=None,
|
210 |
+
metadata={"help": "how often to switch layers in LISA"},
|
211 |
+
)
|
212 |
+
lisa_layers_attribute: Optional[str] = field(
|
213 |
+
default=None,
|
214 |
+
metadata={"help": "path under the model to access the layers"},
|
215 |
+
)
|
216 |
|
217 |
|
218 |
class AxolotlTrainer(Trainer):
|
|
|
951 |
)
|
952 |
callbacks.append(early_stop_cb)
|
953 |
|
954 |
+
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
955 |
+
callbacks.append(lisa_callback_factory(trainer))
|
956 |
return callbacks
|
957 |
|
958 |
def _get_trainer_cls(self):
|
|
|
1244 |
"relora_prune_ratio"
|
1245 |
] = self.cfg.relora_prune_ratio
|
1246 |
|
1247 |
+
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
1248 |
+
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
1249 |
+
training_arguments_kwargs[
|
1250 |
+
"lisa_step_interval"
|
1251 |
+
] = self.cfg.lisa_step_interval
|
1252 |
+
training_arguments_kwargs[
|
1253 |
+
"lisa_layers_attribute"
|
1254 |
+
] = self.cfg.lisa_layers_attribute
|
1255 |
+
|
1256 |
training_arguments_kwargs = self.hook_pre_create_training_args(
|
1257 |
training_arguments_kwargs
|
1258 |
)
|
src/axolotl/utils/callbacks/lisa.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
module for LISA
|
3 |
+
|
4 |
+
Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
|
5 |
+
Arxiv: https://arxiv.org/abs/2403.17919
|
6 |
+
License: Apache 2.0
|
7 |
+
"""
|
8 |
+
|
9 |
+
import logging
|
10 |
+
from functools import reduce
|
11 |
+
from typing import TYPE_CHECKING
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from transformers import TrainerCallback
|
15 |
+
|
16 |
+
if TYPE_CHECKING:
|
17 |
+
from axolotl.core.trainer_builder import AxolotlTrainer
|
18 |
+
|
19 |
+
LOG = logging.getLogger("axolotl.callbacks.lisa")
|
20 |
+
|
21 |
+
|
22 |
+
def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
23 |
+
class LISACallback(TrainerCallback):
|
24 |
+
"""trainer callback for lisa layer switching"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self, n_layers, step_interval, trainer, layers_attribute="model.layers"
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.n_layers = n_layers
|
31 |
+
self.step_interval = step_interval
|
32 |
+
self.layers_attribute = layers_attribute
|
33 |
+
self.trainer = trainer
|
34 |
+
|
35 |
+
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
36 |
+
|
37 |
+
self.total_layers = len(
|
38 |
+
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
39 |
+
)
|
40 |
+
self.active_layers_indices = []
|
41 |
+
|
42 |
+
layers = reduce(
|
43 |
+
getattr, self.layers_attribute.split("."), self.trainer.model
|
44 |
+
)
|
45 |
+
LOG.info(
|
46 |
+
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
|
47 |
+
)
|
48 |
+
|
49 |
+
def freeze_all_layers(self):
|
50 |
+
layers = reduce(
|
51 |
+
getattr, self.layers_attribute.split("."), self.trainer.model
|
52 |
+
)
|
53 |
+
for layer in layers:
|
54 |
+
for param in layer.parameters():
|
55 |
+
param.requires_grad = False
|
56 |
+
|
57 |
+
def on_step_begin(
|
58 |
+
self, args, state, control, **kwargs
|
59 |
+
): # pylint: disable=unused-argument
|
60 |
+
# Check if it's time to switch active layers, including at step 0
|
61 |
+
if state.global_step % self.step_interval == 0 or state.global_step == 1:
|
62 |
+
self.switch_active_layers()
|
63 |
+
|
64 |
+
def switch_active_layers(self):
|
65 |
+
# First, disable gradients for all layers
|
66 |
+
self.freeze_all_layers()
|
67 |
+
|
68 |
+
# Randomly select n_layers to activate
|
69 |
+
layers = reduce(
|
70 |
+
getattr, self.layers_attribute.split("."), self.trainer.model
|
71 |
+
)
|
72 |
+
self.active_layers_indices = np.random.choice(
|
73 |
+
range(self.total_layers), self.n_layers, replace=False
|
74 |
+
)
|
75 |
+
LOG.info(
|
76 |
+
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
77 |
+
)
|
78 |
+
|
79 |
+
# Enable gradients only for the selected layers
|
80 |
+
for idx in self.active_layers_indices:
|
81 |
+
for param in layers[idx].parameters():
|
82 |
+
param.requires_grad = True
|
83 |
+
|
84 |
+
lisa_callback = LISACallback(
|
85 |
+
n_layers=trainer.args.lisa_n_layers,
|
86 |
+
step_interval=trainer.args.lisa_step_interval,
|
87 |
+
trainer=trainer,
|
88 |
+
layers_attribute=trainer.args.lisa_layers_attribute,
|
89 |
+
)
|
90 |
+
|
91 |
+
return lisa_callback
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
|
|
370 |
hf_mlflow_log_artifacts: Optional[bool] = None
|
371 |
|
372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
class WandbConfig(BaseModel):
|
374 |
"""wandb configuration subset"""
|
375 |
|
@@ -404,6 +421,7 @@ class AxolotlInputConfig(
|
|
404 |
HyperparametersConfig,
|
405 |
WandbConfig,
|
406 |
MLFlowConfig,
|
|
|
407 |
RemappedParameters,
|
408 |
DeprecatedParameters,
|
409 |
BaseModel,
|
|
|
370 |
hf_mlflow_log_artifacts: Optional[bool] = None
|
371 |
|
372 |
|
373 |
+
class LISAConfig(BaseModel):
|
374 |
+
"""LISA options"""
|
375 |
+
|
376 |
+
lisa_n_layers: Optional[int] = Field(
|
377 |
+
default=None,
|
378 |
+
metadata={"help": "the number of activate layers in LISA"},
|
379 |
+
)
|
380 |
+
lisa_step_interval: Optional[int] = Field(
|
381 |
+
default=None,
|
382 |
+
metadata={"help": "how often to switch layers in LISA"},
|
383 |
+
)
|
384 |
+
lisa_layers_attribute: Optional[str] = Field(
|
385 |
+
default="model.layers",
|
386 |
+
metadata={"help": "path under the model to access the layers"},
|
387 |
+
)
|
388 |
+
|
389 |
+
|
390 |
class WandbConfig(BaseModel):
|
391 |
"""wandb configuration subset"""
|
392 |
|
|
|
421 |
HyperparametersConfig,
|
422 |
WandbConfig,
|
423 |
MLFlowConfig,
|
424 |
+
LISAConfig,
|
425 |
RemappedParameters,
|
426 |
DeprecatedParameters,
|
427 |
BaseModel,
|