Add callback save peft_model on_save
Browse files
src/axolotl/utils/callbacks.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
4 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
5 |
+
|
6 |
+
class SavePeftModelCallback(TrainerCallback):
|
7 |
+
def on_save(
|
8 |
+
self,
|
9 |
+
args: TrainingArguments,
|
10 |
+
state: TrainerState,
|
11 |
+
control: TrainerControl,
|
12 |
+
**kwargs,
|
13 |
+
):
|
14 |
+
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
15 |
+
|
16 |
+
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
17 |
+
kwargs["model"].save_pretrained(peft_model_path)
|
18 |
+
|
19 |
+
return control
|
src/axolotl/utils/trainer.py
CHANGED
@@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback
|
|
13 |
from transformers.trainer_pt_utils import get_parameter_names
|
14 |
|
15 |
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
|
|
16 |
|
17 |
|
18 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
@@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
188 |
data_collator_kwargs["padding"] = "longest"
|
189 |
else:
|
190 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
|
|
|
|
|
|
|
|
|
|
191 |
trainer = transformers.Trainer(
|
192 |
model=model,
|
193 |
train_dataset=train_dataset,
|
|
|
13 |
from transformers.trainer_pt_utils import get_parameter_names
|
14 |
|
15 |
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
16 |
+
from axolotl.utils.callbacks import SavePeftModelCallback
|
17 |
|
18 |
|
19 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
|
189 |
data_collator_kwargs["padding"] = "longest"
|
190 |
else:
|
191 |
data_collator_kwargs["pad_to_multiple_of"] = 8
|
192 |
+
|
193 |
+
callbacks = []
|
194 |
+
if cfg.adapter == 'lora':
|
195 |
+
callbacks.append(SavePeftModelCallback)
|
196 |
+
|
197 |
trainer = transformers.Trainer(
|
198 |
model=model,
|
199 |
train_dataset=train_dataset,
|