support for multi-line inference input, log sweep over learning rates
Changed files:
- scripts/finetune.py (+17 -10)
- src/axolotl/utils/schedulers.py (+34 -0)
- src/axolotl/utils/trainer.py (+17 -2)
scripts/finetune.py
CHANGED

@@ -1,5 +1,7 @@
+import importlib
 import logging
 import os
+import pathlib
 import random
 import signal
 import sys

@@ -44,18 +46,20 @@ def choose_device(cfg):
     cfg.device_map = {"": cfg.device}


-def do_inference(cfg, model, tokenizer):
+def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
     tokenizer.add_special_tokens({"eos_token": "</s>"})

+    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)

     while True:
+        # support for multiline inputs
+        print("Give me an instruction (Ctrl + D to finish): ")
+        instruction = pathlib.Path("/proc/self/fd/0").read_text()
         if not instruction:
             return
+        prompt = prompter_module().build_prompt(instruction=instruction)
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         model.eval()

@@ -162,6 +166,10 @@ def train(
         do_inference(cfg, model, tokenizer)
         return

+    if "shard" in kwargs:
+        model.save_pretrained(cfg.output_dir)
+        return
+
     train_dataset, eval_dataset = load_prepare_datasets(
         tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
     )

@@ -207,12 +215,11 @@ def train(
         logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-    model.save_pretrained(cfg.output_dir)
+    logging.info(
+        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+    )
+    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+    trainer.save_model(cfg.output_dir)


 if __name__ == "__main__":
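For context on the inference change: the prompt is now read from standard input until EOF, so multi-line instructions are submitted with Ctrl + D, and the prompter class is resolved by name from axolotl.prompters via importlib, so a different prompter can be passed through the new prompter argument. Below is a minimal, standalone sketch of the same input pattern; the helper name read_instruction is invented for illustration, and sys.stdin.read() is used as a portable stand-in for reading /proc/self/fd/0 (which only exists on Linux).

import sys


def read_instruction() -> str:
    # Read everything typed on stdin until the user signals EOF (Ctrl + D on
    # Linux/macOS), so the instruction may span any number of lines. This is
    # what pathlib.Path("/proc/self/fd/0").read_text() achieves in the diff
    # above, minus the Linux-only /proc dependency.
    print("Give me an instruction (Ctrl + D to finish): ")
    return sys.stdin.read()


if __name__ == "__main__":
    instruction = read_instruction()
    if not instruction:
        print("No input given")
    else:
        print(f"Received {len(instruction.splitlines())} line(s) of input")

In do_inference the read sits inside the existing while True loop, and an empty read (Ctrl + D with no text) ends the session via the if not instruction: return check.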
src/axolotl/utils/schedulers.py
ADDED

@@ -0,0 +1,34 @@
+from torch.optim.lr_scheduler import LRScheduler
+
+
+class InterpolatingLogScheduler(LRScheduler):
+    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
+        """A scheduler that interpolates learning rates in a logarithmic fashion
+
+        Args:
+        - optimizer: pytorch optimizer
+        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
+        - min_lr: float, the minimum learning rate
+        - max_lr: float, the maximum learning rate
+
+        Usage:
+            fc = nn.Linear(1,1)
+            optimizer = optim.Adam(fc.parameters())
+            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
+        """
+        self.num_steps = num_steps
+        self.min_lr = min_lr
+        self.max_lr = max_lr
+        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch == 0:
+            lr = self.min_lr
+        elif self.last_epoch < self.num_steps:
+            # FIXME, not perfect as we need to account for number of steps are in an epoch, etc
+            lr = self.min_lr * (self.q ** self.last_epoch)
+        else:
+            lr = self.max_lr
+
+        return [lr for _ in self.base_lrs]
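To make the behaviour of the new scheduler concrete, here is a small usage sketch (not part of the commit): it drives InterpolatingLogScheduler with a throwaway linear layer and prints the resulting log-spaced sweep. It assumes the ratio is computed as q = (max_lr / min_lr) ** (1 / (num_steps - 1)) as above, so the sweep reaches max_lr on the last interpolated step, and it needs a PyTorch version that exports torch.optim.lr_scheduler.LRScheduler (2.0+).

from torch import nn, optim

from axolotl.utils.schedulers import InterpolatingLogScheduler

# Throwaway model and optimizer purely to drive the scheduler; the optimizer's
# base lr is irrelevant because get_lr() derives its values from min_lr/max_lr.
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=1e-6)
scheduler = InterpolatingLogScheduler(optimizer, num_steps=5, min_lr=1e-6, max_lr=1e-2)

for step in range(7):
    print(step, scheduler.get_last_lr())
    optimizer.step()
    scheduler.step()

# Prints (approximately) 1e-06, 1e-05, 1e-04, 1e-03, 1e-02 over the first five
# steps, then stays pinned at max_lr once num_steps has elapsed.

The trainer change below wires this up when cfg.lr_scheduler == "log_sweep", sweeping between cfg.log_sweep_min_lr and cfg.log_sweep_max_lr over cfg.warmup_steps steps.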
src/axolotl/utils/trainer.py
CHANGED

@@ -12,6 +12,8 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback
 from transformers.trainer_pt_utils import get_parameter_names

+from axolotl.utils.schedulers import InterpolatingLogScheduler
+

 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(

@@ -27,11 +29,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
+    save_steps = (
         cfg.save_steps
         if cfg.save_steps is not None
         else min(int(0.05 * total_num_steps), 200)
     )
+    eval_steps = (
+        cfg.eval_steps
+        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
+        else save_steps
+    )

     training_arguments_kwargs = {}
     if cfg.bf16 == "full":

@@ -95,7 +102,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else None,
+        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
         **training_arguments_kwargs,
     )

@@ -147,8 +154,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             optimizer,
             cfg.learning_rate,
             total_steps=total_num_steps,
+            epochs=cfg.num_epochs,
             **lr_scheduler_kwargs,
         )
+    elif cfg.lr_scheduler == "log_sweep":
+        lr_scheduler = InterpolatingLogScheduler(
+            optimizer,
+            cfg.warmup_steps,
+            cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
+            cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
+        )
     else:
         lr_scheduler = transformers.get_cosine_schedule_with_warmup(
             optimizer,