WIP large refactor to make finetune script a little more manageable (#3)
- configs/gpt_neox_20b.yml +39 -0
- scripts/finetune.py +18 -398
- src/axolotl/prompt_tokenizers.py +10 -0
- src/axolotl/prompters.py +4 -0
- src/axolotl/utils/__init__.py +0 -0
- src/axolotl/utils/data.py +125 -0
- src/axolotl/utils/models.py +206 -0
- src/axolotl/utils/trainer.py +109 -0
- src/axolotl/utils/wandb.py +13 -0
configs/gpt_neox_20b.yml
ADDED
@@ -0,0 +1,39 @@
+base_model: EleutherAI/gpt-neox-20b
+base_model_ignore_patterns: pytorch*  # prefer safetensors
+model_type: GPTNeoXForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: true
+datasets:
+  - path: nomic-ai/gpt4all-j-prompt-generations
+    type: alpaca
+    shards: 4
+    shards_index: 0
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter: lora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 8
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_modules:
+  - query_key_value
+lora_fan_in_fan_out: true  # pythia/GPTNeoX lora specific
+wandb_project: gpt4all-neox-20b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./gpt4all-neox-20b
+batch_size: 48
+micro_batch_size: 4
+num_epochs: 5
+learning_rate: 0.00003
+lr_scheduler: one_cycle
+train_on_inputs: false
+group_by_length: false
+bf16: True
+tf32: True
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
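For orientation, a minimal sketch (not part of the diff) of how a config like the one above is consumed. It assumes the yaml/AttrDefault imports already present in scripts/finetune.py; the exact loading code there may differ.

import yaml
from attrdict import AttrDefault

with open("configs/gpt_neox_20b.yml", encoding="utf-8") as f:
    # AttrDefault returns None for keys the YAML leaves out, which is why the
    # code in this PR can freely test cfg.load_4bit, cfg.flash_attention, etc.
    cfg = AttrDefault(lambda: None, yaml.safe_load(f))

assert cfg.base_model == "EleutherAI/gpt-neox-20b"
assert cfg.load_4bit is None  # unset keys fall back to the default factory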
scripts/finetune.py
CHANGED
@@ -1,223 +1,29 @@
 import logging
-import math
 import os
 import random
 import signal
 import sys
-from hashlib import md5
 from pathlib import Path
 
-import bitsandbytes as bnb
 import fire
 import torch
-import transformers
 import yaml
 from attrdict import AttrDefault
-from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
-from torch import nn
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    LlamaForCausalLM,
-    LlamaTokenizer,
-    EarlyStoppingCallback,
-    GenerationConfig,
-)
 
 # add src to the pythonpath so we don't need to pip install this
-from transformers.trainer_pt_utils import get_parameter_names
-
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
-from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
-from axolotl.prompt_tokenizers import (
-    AlpacaPromptTokenizingStrategy,
-    ShareGPTPromptTokenizingStrategy,
-    LLAMA_DEFAULT_PAD_TOKEN,
-    GPTeacherPromptTokenizingStrategy,
-    OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy,
-)
-from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter
+from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.models import load_model
+from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.wandb import setup_wandb_env_vars
 
 logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 
 
-def setup_wandb_env_vars(cfg):
-    if cfg.wandb_project and len(cfg.wandb_project) > 0:
-        os.environ["WANDB_PROJECT"] = cfg.wandb_project
-        cfg.use_wandb = True
-    if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
-        os.environ["WANDB_WATCH"] = cfg.wandb_watch
-    if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
-        os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
-    if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
-        os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
-
-
-def load_model(
-    base_model,
-    base_model_config,
-    model_type,
-    tokenizer_type,
-    cfg,
-    adapter="lora",
-    inference: bool = False,
-):
-    # TODO refactor as a kwarg
-    load_in_8bit = cfg.load_in_8bit
-    tokenizer = None
-    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
-
-    if adapter != "lora":
-        raise NotImplementedError(f"{adapter} peft adapter not available")
-    if is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
-
-            logging.info("patching with flash attention")
-            replace_llama_attn_with_flash_attn()
-
-    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
-    try:
-        if cfg.load_4bit:
-            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
-                replace_peft_model_with_int4_lora_model,
-            )
-
-            replace_peft_model_with_int4_lora_model()
-
-        from peft import (
-            LoraConfig,
-            get_peft_model,
-            prepare_model_for_int8_training,
-            PeftModel,
-        )
-    except Exception as e:
-        logging.exception(e)
-        raise e
-
-    try:
-        if cfg.load_4bit and is_llama_derived_model:
-            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
-            from huggingface_hub import snapshot_download
-
-            cache_model_path = Path(snapshot_download(base_model))
-            files = (
-                list(cache_model_path.glob("*.pt"))
-                + list(cache_model_path.glob("*.safetensors"))
-                + list(cache_model_path.glob("*.bin"))
-            )
-            if len(files) > 0:
-                model_path = str(files[0])
-            else:
-                logging.warning(
-                    "unable to find a cached model file, this will likely fail..."
-                )
-                model_path = str(cache_model_path)
-            model, tokenizer = load_llama_model_4bit_low_ram(
-                base_model_config if base_model_config else base_model,
-                model_path,
-                device_map=cfg.device_map,
-                groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
-                is_v1_model=cfg.gptq_model_v1
-                if cfg.gptq_model_v1 is not None
-                else True,
-            )
-            load_in_8bit = False
-        elif is_llama_derived_model:
-            model = LlamaForCausalLM.from_pretrained(
-                base_model,
-                load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch_dtype,
-                device_map=cfg.device_map,
-            )
-        else:
-            model = getattr(transformers, model_type).from_pretrained(
-                base_model,
-                load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch_dtype,
-                device_map=cfg.device_map,
-            )
-    except Exception as e:
-        logging.error(
-            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
-        )
-        logging.exception(e)
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
-        )
-
-    if not tokenizer:
-        try:
-            if is_llama_derived_model:
-                tokenizer = LlamaTokenizer.from_pretrained(model)
-            else:
-                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
-        except:
-            tokenizer = AutoTokenizer.from_pretrained(base_model)
-
-    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
-        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    if load_in_8bit and not cfg.load_4bit:
-        logging.info("converting model w/ prepare_model_for_int8_training")
-        model = prepare_model_for_int8_training(model)
-
-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=cfg.lora_target_modules,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        bias="none",
-        task_type="CAUSAL_LM",
-    )
-
-    if cfg.lora_model_dir:
-        model = PeftModel.from_pretrained(
-            model,
-            cfg.lora_model_dir,
-            device_map=cfg.device_map,
-            torch_dtype=torch.float16,
-        )
-    else:
-        model = get_peft_model(model, lora_config)
-
-    if cfg.ddp:
-        model.to(f"cuda:{cfg.local_rank}")
-
-    if cfg.load_4bit:
-        # Scales to half
-        logging.info("Fitting 4bit scales and zeros to half")
-        for n, m in model.named_modules():
-            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
-                type(m)
-            ):
-                if hasattr(m, "is_v1_model") and m.is_v1_model:
-                    m.zeros = m.zeros.half()
-                m.scales = m.scales.half()
-                m.bias = m.bias.half()
-
-    # TODO resume_from_checkpoint handling
-    model.print_trainable_parameters()
-    return model, tokenizer, lora_config
-
-
 def choose_device(cfg):
     def get_device():
         if torch.cuda.is_available():
@@ -271,11 +77,13 @@ def do_inference(cfg, model, tokenizer):
     tokenizer.add_special_tokens({"bos_token": "<s>"})
     tokenizer.add_special_tokens({"eos_token": "</s>"})
 
-
-
-
-
+    from axolotl.prompters import ReflectAlpacaPrompter
+
+    instruction = str(input("Give me an instruction: "))
+    instruction = (
+        instruction if not instruction else "Tell me a joke about dromedaries."
     )
+    prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
     batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
     model.eval()
@@ -324,98 +132,6 @@ def choose_config(path: Path):
     return chosen_file
 
 
-def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
-    total_num_steps = int(
-        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
-    )
-    warmup_steps = min(int(0.03 * total_num_steps), 100)
-    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
-    save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
-
-    training_arguments_kwargs = {}
-    if cfg.bf16 == "full":
-        training_arguments_kwargs["bf16_full_eval"] = True
-    else:
-        training_arguments_kwargs["bf16"] = cfg.bf16
-    training_arguments_kwargs["tf32"] = cfg.tf32
-    training_arguments_kwargs["warmup_steps"] = warmup_steps
-    training_arguments_kwargs["logging_steps"] = logging_steps
-    if cfg.gradient_checkpointing is not None:
-        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
-
-    training_args = transformers.TrainingArguments(
-        per_device_train_batch_size=cfg.micro_batch_size,
-        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
-        num_train_epochs=cfg.num_epochs,
-        learning_rate=cfg.learning_rate,
-        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps",
-        eval_steps=eval_steps if cfg.val_set_size > 0 else None,
-        save_steps=save_steps,
-        output_dir=cfg.output_dir,
-        save_total_limit=3,
-        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
-        ddp_find_unused_parameters=False if cfg.ddp else None,
-        group_by_length=cfg.group_by_length,
-        report_to="wandb" if cfg.use_wandb else None,
-        run_name=cfg.wandb_run_id if cfg.use_wandb else None,
-        **training_arguments_kwargs,
-    )
-
-    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-    decay_parameters = [name for name in decay_parameters if "bias" not in name]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [
-                p for n, p in model.named_parameters() if n not in decay_parameters
-            ],
-            "weight_decay": 0.0,
-        },
-    ]
-
-    trainer_kwargs = {}
-
-    if cfg.load_in_8bit and not cfg.load_4bit:
-        adam_bnb_optim = bnb.optim.Adam8bit(
-            optimizer_grouped_parameters,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            eps=training_args.adam_epsilon,
-            lr=training_args.learning_rate,
-        )
-
-        # TODO optionally use torch.optim.OneCycleLR
-        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-            adam_bnb_optim,
-            training_args.warmup_steps,
-            total_num_steps,
-        )
-        trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
-
-    # TODO on_save callback to sync checkpoints to GCP/AWS in background
-    if cfg.early_stopping_patience:
-        early_stop_cb = EarlyStoppingCallback(
-            cfg.early_stopping_patience,
-        )
-        trainer_kwargs["callbacks"] = [early_stop_cb]
-
-    trainer = transformers.Trainer(
-        model=model,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        args=training_args,
-        data_collator=transformers.DataCollatorForSeq2Seq(
-            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
-        ),
-        **trainer_kwargs,
-    )
-
-    return trainer
-
-
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -474,110 +190,13 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
-    ds_hash = str(
-        md5(
-            (
-                str(max_packed_sequence_len)
-                + "@"
-                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
-            ).encode("utf-8")
-        ).hexdigest()
-    )
-    prepared_ds_path = (
-        Path(cfg.dataset_prepared_path) / ds_hash
-        if cfg.dataset_prepared_path
-        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    train_dataset, eval_dataset = load_prepare_datasets(
+        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
     )
 
-    if any(prepared_ds_path.glob("*")):
-        logging.info("Loading prepared dataset from disk...")
-        dataset = load_from_disk(str(prepared_ds_path))
-        logging.info("Prepared dataset loaded from disk...")
-    else:
-        logging.info("Loading raw datasets...")
-        datasets = []
-        for d in cfg.datasets:
-            ds_from_hub = False
-            try:
-                load_dataset(d.path, streaming=True)
-                ds_from_hub = True
-            except FileNotFoundError:
-                pass
-
-            # prefer local dataset, even if hub exists
-            if Path(d.path).exists():
-                ds: IterableDataset = load_dataset(
-                    "json", data_files=d.path, streaming=True, split=None
-                )
-            elif ds_from_hub:
-                ds = load_dataset(d.path, streaming=True)
-            else:
-                raise Exception("unhandled dataset load")
-
-            if d.type == "alpaca":
-                ds_strategy = AlpacaPromptTokenizingStrategy(
-                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
-                )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
-                datasets.append(ds_wrapper)
-            elif d.type == "oasst":
-                ds_strategy = OpenAssistantPromptTokenizingStrategy(
-                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
-                )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
-                datasets.append(ds_wrapper)
-            elif d.type == "gpteacher":
-                ds_strategy = GPTeacherPromptTokenizingStrategy(
-                    GPTeacherPrompter(),
-                    tokenizer,
-                    cfg.train_on_inputs,
-                    cfg.sequence_len,
-                )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
-                datasets.append(ds_wrapper)
-            elif d.type == "reflection":
-                ds_strategy = AlpacaReflectionPTStrategy(
-                    ReflectAlpacaPrompter(),
-                    tokenizer,
-                    cfg.train_on_inputs,
-                    cfg.sequence_len,
-                )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
-                datasets.append(ds_wrapper)
-            elif d.type == "sharegpt":
-                ds_strategy = ShareGPTPromptTokenizingStrategy(
-                    ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
-                )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
-                datasets.append(ds_wrapper)
-            else:
-                logging.error(f"unhandled prompt tokenization strategy: {d.type}")
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            datasets,
-            seq_length=max_packed_sequence_len,
-        )
-        logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
-            test_size=cfg.val_set_size, shuffle=True, seed=42
-        )
-
-        if cfg.local_rank == 0:
-            logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
-            dataset.save_to_disk(prepared_ds_path)
-
-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
-    train_dataset = dataset["train"]
-    eval_dataset = dataset["test"]
+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
 
     if cfg.debug:
         check_dataset_labels(
@@ -594,8 +213,9 @@ def train(
         model = torch.compile(model)
 
     # go ahead and presave, so we have the adapter config available to inspect
-
-
+    if lora_config:
+        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
+        lora_config.save_pretrained(cfg.output_dir)
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
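As a usage sketch (not from this PR), here is the control flow train() in scripts/finetune.py is left with after the refactor, with config loading and the inference/debug branches elided. The run() wrapper and the exact argument plumbing are illustrative only.

from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars


def run(cfg):  # hypothetical wrapper around the real train() body
    setup_wandb_env_vars(cfg)  # exports WANDB_* and flips cfg.use_wandb
    model, tokenizer, lora_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        cfg.tokenizer_type,
        cfg,
        adapter=cfg.adapter,
    )
    train_dataset, eval_dataset = load_prepare_datasets(
        tokenizer, cfg, "last_run_prepared"
    )
    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
    model.save_pretrained(cfg.output_dir)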
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -107,6 +107,15 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         )
 
 
+class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["prompt"],
+            "",
+            prompt["response"],
+        )
+
+
 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
     def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
         raise NotImplementedError
@@ -168,6 +177,7 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
             prompt["corrected"],
         )
 
+
 class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
         try:
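To illustrate what the new NomicGPT4AllPromptTokenizingStrategy maps (not part of the diff): it reads "prompt" and "response" fields, presumably the gpt4all-j prompt-generations schema, and feeds them through the shared instruction/input/response pipeline with an empty input.

# A gpt4all-style record, with fields as assumed by parse_instruction_fields():
record = {"prompt": "Name three deserts.", "response": "Sahara, Gobi, Atacama."}

# parse_instruction_fields(record) returns exactly this tuple, so downstream
# tokenization can reuse the Alpaca-style template with no "input" section.
instruction, model_input, response = record["prompt"], "", record["response"]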
src/axolotl/prompters.py
CHANGED
@@ -35,6 +35,10 @@ class GPTeacherPrompter(AlpacaPrompter):
     ...
 
 
+class NomicGPT4AllPrompter(AlpacaPrompter):
+    ...
+
+
 class ReflectAlpacaPrompter:
     prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
     prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
src/axolotl/utils/__init__.py
ADDED
File without changes
src/axolotl/utils/data.py
ADDED
@@ -0,0 +1,125 @@
+import logging
+from hashlib import md5
+from pathlib import Path
+
+from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+
+from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    GPTeacherPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
+    AlpacaReflectionPTStrategy,
+    ShareGPTPromptTokenizingStrategy,
+)
+from axolotl.prompters import (
+    AlpacaPrompter,
+    GPTeacherPrompter,
+    ReflectAlpacaPrompter,
+    ShareGPTPrompter,
+)
+
+
+def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+    ds_hash = str(
+        md5(
+            (
+                str(max_packed_sequence_len)
+                + "@"
+                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+            ).encode("utf-8")
+        ).hexdigest()
+    )
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(default_dataset_prepared_path) / ds_hash
+    )
+
+    if any(prepared_ds_path.glob("*")):
+        logging.info("Loading prepared dataset from disk...")
+        dataset = load_from_disk(str(prepared_ds_path))
+        logging.info("Prepared dataset loaded from disk...")
+    else:
+        logging.info("Loading raw datasets...")
+        datasets = []
+        for d in cfg.datasets:
+            ds_from_hub = False
+            try:
+                load_dataset(d.path, streaming=True)
+                ds_from_hub = True
+            except FileNotFoundError:
+                pass
+
+            # prefer local dataset, even if hub exists
+            if Path(d.path).exists():
+                ds: IterableDataset = load_dataset(
+                    "json", data_files=d.path, streaming=True, split=None
+                )
+            elif ds_from_hub:
+                ds = load_dataset(d.path, streaming=True)
+            else:
+                raise Exception("unhandled dataset load")
+
+            if d.type == "alpaca":
+                ds_strategy = AlpacaPromptTokenizingStrategy(
+                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
+            elif d.type == "oasst":
+                ds_strategy = OpenAssistantPromptTokenizingStrategy(
+                    AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
+            elif d.type == "gpteacher":
+                ds_strategy = GPTeacherPromptTokenizingStrategy(
+                    GPTeacherPrompter(),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
+            elif d.type == "reflection":
+                ds_strategy = AlpacaReflectionPTStrategy(
+                    ReflectAlpacaPrompter(),
+                    tokenizer,
+                    cfg.train_on_inputs,
+                    cfg.sequence_len,
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
+            elif d.type == "sharegpt":
+                ds_strategy = ShareGPTPromptTokenizingStrategy(
+                    ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                )
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                datasets.append(ds_wrapper)
+            else:
+                logging.error(f"unhandled prompt tokenization strategy: {d.type}")
+        constant_len_dataset = ConstantLengthDataset(
+            tokenizer,
+            datasets,
+            seq_length=max_packed_sequence_len,
+        )
+        logging.info("merging, packing, shuffling, and splitting master dataset")
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+            test_size=cfg.val_set_size, shuffle=True, seed=42
+        )
+
+        if cfg.local_rank == 0:
+            logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
+            dataset.save_to_disk(prepared_ds_path)
+
+    train_dataset = dataset["train"]
+    eval_dataset = dataset["test"]
+
+    return train_dataset, eval_dataset
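For illustration (not part of the diff), the dataset-cache key from load_prepare_datasets() in isolation: the hash covers the packed sequence length plus the sorted path:type pairs, so changing either the datasets list or max_packed_sequence_len selects a different prepared-dataset directory instead of reusing a stale one.

from hashlib import md5
from pathlib import Path

max_packed_sequence_len = 2048
datasets = [("nomic-ai/gpt4all-j-prompt-generations", "alpaca")]  # (path, type) pairs

ds_hash = md5(
    (
        str(max_packed_sequence_len)
        + "@"
        + "|".join(sorted(f"{path}:{dtype}" for path, dtype in datasets))
    ).encode("utf-8")
).hexdigest()
prepared_ds_path = Path("last_run_prepared") / ds_hash  # e.g. last_run_prepared/<md5>
print(prepared_ds_path)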
src/axolotl/utils/models.py
ADDED
@@ -0,0 +1,206 @@
+import logging
+import os
+from pathlib import Path
+from typing import Optional, Tuple, TYPE_CHECKING
+
+import torch
+import transformers
+from transformers import (
+    AutoModelForCausalLM,
+    LlamaForCausalLM,
+    LlamaTokenizer,
+    AutoTokenizer,
+    PreTrainedModel,
+)
+
+from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
+
+if TYPE_CHECKING:
+    from peft import PeftModel, PeftConfig
+    from attrdict import AttrDefault
+    from transformers import PreTrainedTokenizer
+
+
+def load_model(
+    base_model,
+    base_model_config,
+    model_type,
+    tokenizer_type,
+    cfg,
+    adapter="lora",
+    inference=False,
+):
+    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+
+    # TODO refactor as a kwarg
+    load_in_8bit = cfg.load_in_8bit
+    tokenizer = None
+    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
+
+    if is_llama_derived_model and cfg.flash_attention:
+        if cfg.device not in ["mps", "cpu"] and inference is False:
+            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+
+            logging.info("patching with flash attention")
+            replace_llama_attn_with_flash_attn()
+
+    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
+    try:
+        if cfg.load_4bit:
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
+                replace_peft_model_with_int4_lora_model,
+            )
+
+            replace_peft_model_with_int4_lora_model()
+        from peft import prepare_model_for_int8_training
+    except Exception as e:
+        logging.exception(e)
+        raise e
+
+    try:
+        if cfg.load_4bit and is_llama_derived_model:
+            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
+            from huggingface_hub import snapshot_download
+
+            cache_model_path = Path(snapshot_download(base_model))
+            files = (
+                list(cache_model_path.glob("*.pt"))
+                + list(cache_model_path.glob("*.safetensors"))
+                + list(cache_model_path.glob("*.bin"))
+            )
+            if len(files) > 0:
+                model_path = str(files[0])
+            else:
+                logging.warning(
+                    "unable to find a cached model file, this will likely fail..."
+                )
+                model_path = str(cache_model_path)
+            model, tokenizer = load_llama_model_4bit_low_ram(
+                base_model_config if base_model_config else base_model,
+                model_path,
+                device_map=cfg.device_map,
+                groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
+                is_v1_model=cfg.gptq_model_v1
+                if cfg.gptq_model_v1 is not None
+                else True,
+            )
+            load_in_8bit = False
+        elif is_llama_derived_model:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = getattr(transformers, model_type).from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
+    except Exception as e:
+        logging.error(
+            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
+        )
+        logging.exception(e)
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=cfg.load_in_8bit,
+            torch_dtype=torch_dtype,
+            device_map=cfg.device_map,
+        )
+
+    if not tokenizer:
+        try:
+            if is_llama_derived_model:
+                tokenizer = LlamaTokenizer.from_pretrained(model)
+            else:
+                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
+        except:
+            tokenizer = AutoTokenizer.from_pretrained(base_model)
+
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
+
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    if load_in_8bit and not cfg.load_4bit:
+        logging.info("converting model w/ prepare_model_for_int8_training")
+        model = prepare_model_for_int8_training(model)
+
+    model, lora_config = load_adapter(model, cfg, adapter)
+
+    if cfg.ddp:
+        model.to(f"cuda:{cfg.local_rank}")
+
+    if cfg.load_4bit:
+        # Scales to half
+        logging.info("Fitting 4bit scales and zeros to half")
+        for n, m in model.named_modules():
+            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
+                type(m)
+            ):
+                if hasattr(m, "is_v1_model") and m.is_v1_model:
+                    m.zeros = m.zeros.half()
+                m.scales = m.scales.half()
+                m.bias = m.bias.half()
+
+    # TODO resume_from_checkpoint handling
+    return model, tokenizer, lora_config
+
+
+def load_adapter(model, cfg, adapter):
+    # type: (PreTrainedModel, AttrDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+
+    if adapter is None:
+        return model, None
+    if adapter == "lora":
+        return load_lora(model, cfg)
+    # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
+
+    raise NotImplementedError(f"{adapter} peft adapter not available")
+
+
+def load_lora(model, cfg):
+    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+
+    from peft import (
+        LoraConfig,
+        get_peft_model,
+        PeftModel,
+    )
+
+    lora_config = None
+
+    if cfg.adapter == "lora":
+        lora_config = LoraConfig(
+            r=cfg.lora_r,
+            lora_alpha=cfg.lora_alpha,
+            target_modules=cfg.lora_target_modules,
+            lora_dropout=cfg.lora_dropout,
+            fan_in_fan_out=cfg.lora_fan_in_fan_out,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+    if cfg.lora_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, lora_config)
+
+    model.print_trainable_parameters()
+
+    return model, lora_config
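A minimal sketch (not part of the diff) of calling the new loader for the gpt_neox_20b.yml config above, assuming cfg was built from that YAML; base_model_config is passed as None purely for illustration, since it is only consumed on the 4-bit llama path.

from axolotl.utils.models import load_model

model, tokenizer, lora_config = load_model(
    "EleutherAI/gpt-neox-20b",  # base_model
    None,                       # base_model_config (4-bit llama path only)
    "GPTNeoXForCausalLM",       # model_type, resolved via getattr(transformers, ...)
    "AutoTokenizer",            # tokenizer_type
    cfg,                        # AttrDefault config with load_in_8bit, lora_*, ddp, ...
    adapter="lora",
)
# load_adapter() wraps the 8-bit model in a PeftModel and hands back the
# LoraConfig so finetune.py can pre-save it to cfg.output_dir.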
src/axolotl/utils/trainer.py
ADDED
@@ -0,0 +1,109 @@
+import math
+import bitsandbytes as bnb
+import transformers
+from torch import nn
+from torch.optim.lr_scheduler import OneCycleLR
+from transformers import EarlyStoppingCallback
+from transformers.trainer_pt_utils import get_parameter_names
+
+
+def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
+    total_num_steps = int(
+        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+    )
+    warmup_steps = min(int(0.03 * total_num_steps), 100)
+    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
+    save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
+
+    training_arguments_kwargs = {}
+    if cfg.bf16 == "full":
+        training_arguments_kwargs["bf16_full_eval"] = True
+    else:
+        training_arguments_kwargs["bf16"] = cfg.bf16
+    training_arguments_kwargs["tf32"] = cfg.tf32
+    training_arguments_kwargs["warmup_steps"] = warmup_steps
+    training_arguments_kwargs["logging_steps"] = logging_steps
+    if cfg.gradient_checkpointing is not None:
+        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+
+    training_args = transformers.TrainingArguments(
+        per_device_train_batch_size=cfg.micro_batch_size,
+        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+        num_train_epochs=cfg.num_epochs,
+        learning_rate=cfg.learning_rate,
+        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
+        save_strategy="steps",
+        eval_steps=eval_steps if cfg.val_set_size > 0 else None,
+        save_steps=save_steps,
+        output_dir=cfg.output_dir,
+        save_total_limit=3,
+        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        ddp_find_unused_parameters=False if cfg.ddp else None,
+        group_by_length=cfg.group_by_length,
+        report_to="wandb" if cfg.use_wandb else None,
+        run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        **training_arguments_kwargs,
+    )
+
+    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [
+                p for n, p in model.named_parameters() if n not in decay_parameters
+            ],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    trainer_kwargs = {}
+
+    if cfg.load_in_8bit and not cfg.load_4bit:
+        optimizer = bnb.optim.Adam8bit(
+            optimizer_grouped_parameters,
+            betas=(training_args.adam_beta1, training_args.adam_beta2),
+            eps=training_args.adam_epsilon,
+            lr=training_args.learning_rate,
+        )
+
+        if cfg.lr_scheduler == "one_cycle":
+            lr_scheduler_kwargs = (
+                cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
+            )
+            lr_scheduler = OneCycleLR(
+                optimizer,
+                cfg.learning_rate,
+                total_steps=total_num_steps,
+                **lr_scheduler_kwargs,
+            )
+        else:
+            lr_scheduler = transformers.get_cosine_schedule_with_warmup(
+                optimizer,
+                training_args.warmup_steps,
+                total_num_steps,
+            )
+        trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
+
+    # TODO on_save callback to sync checkpoints to GCP/AWS in background
+    if cfg.early_stopping_patience:
+        early_stop_cb = EarlyStoppingCallback(
+            cfg.early_stopping_patience,
+        )
+        trainer_kwargs["callbacks"] = [early_stop_cb]
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        args=training_args,
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+        **trainer_kwargs,
+    )
+
+    return trainer
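Not part of the diff: the scheduler selection above in miniature. With lr_scheduler: one_cycle in the config (as in gpt_neox_20b.yml) and an 8-bit load, the Adam8bit optimizer is paired with torch's OneCycleLR rather than the cosine schedule; on any other path the Trainer keeps its own default optimizer and scheduler. The toy parameters below are stand-ins.

import torch
from torch.optim.lr_scheduler import OneCycleLR

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for the grouped model parameters
optimizer = torch.optim.Adam(params, lr=3e-5)   # stands in for bnb.optim.Adam8bit
scheduler = OneCycleLR(optimizer, max_lr=3e-5, total_steps=1_000)

for _ in range(3):
    optimizer.step()    # no-op here (no gradients), just to show the pairing
    scheduler.step()    # one LR update per optimizer step, as the Trainer does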
src/axolotl/utils/wandb.py
ADDED
@@ -0,0 +1,13 @@
+import os
+
+
+def setup_wandb_env_vars(cfg):
+    if cfg.wandb_project and len(cfg.wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = cfg.wandb_project
+        cfg.use_wandb = True
+    if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
+        os.environ["WANDB_WATCH"] = cfg.wandb_watch
+    if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
+        os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
+    if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
+        os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
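Finally, a small illustration (not part of the diff) of how this helper connects the YAML keys to the Trainer: setting wandb_project makes setup_wandb_env_vars() export WANDB_PROJECT and flip cfg.use_wandb, which setup_trainer() reads for report_to and run_name. The AttrDefault construction here is illustrative.

import os
from attrdict import AttrDefault
from axolotl.utils.wandb import setup_wandb_env_vars

cfg = AttrDefault(lambda: None, {"wandb_project": "gpt4all-neox-20b", "wandb_log_model": "checkpoint"})
setup_wandb_env_vars(cfg)

assert os.environ["WANDB_PROJECT"] == "gpt4all-neox-20b"
assert os.environ["WANDB_LOG_MODEL"] == "checkpoint"
assert cfg.use_wandb is True   # setup_trainer() keys report_to="wandb" off this flag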