winglian commited on
Commit
6045345
1 Parent(s): 81de0ef

WIP large refactor to make finetune script a little more manageable (#3)

Browse files
configs/gpt_neox_20b.yml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: EleutherAI/gpt-neox-20b
2
+ base_model_ignore_patterns: pytorch* # prefer safetensors
3
+ model_type: GPTNeoXForCausalLM
4
+ tokenizer_type: AutoTokenizer
5
+ load_in_8bit: true
6
+ datasets:
7
+ - path: nomic-ai/gpt4all-j-prompt-generations
8
+ type: alpaca
9
+ shards: 4
10
+ shards_index: 0
11
+ dataset_prepared_path: last_run_prepared
12
+ val_set_size: 0.05
13
+ adapter: lora
14
+ lora_model_dir:
15
+ sequence_len: 2048
16
+ max_packed_sequence_len: 2048
17
+ lora_r: 8
18
+ lora_alpha: 32
19
+ lora_dropout: 0.05
20
+ lora_target_modules:
21
+ - query_key_value
22
+ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
23
+ wandb_project: gpt4all-neox-20b
24
+ wandb_watch:
25
+ wandb_run_id:
26
+ wandb_log_model: checkpoint
27
+ output_dir: ./gpt4all-neox-20b
28
+ batch_size: 48
29
+ micro_batch_size: 4
30
+ num_epochs: 5
31
+ learning_rate: 0.00003
32
+ lr_scheduler: one_cycle
33
+ train_on_inputs: false
34
+ group_by_length: false
35
+ bf16: True
36
+ tf32: True
37
+ early_stopping_patience:
38
+ resume_from_checkpoint:
39
+ local_rank:
scripts/finetune.py CHANGED
@@ -1,223 +1,29 @@
1
  import logging
2
- import math
3
  import os
4
  import random
5
  import signal
6
  import sys
7
- from hashlib import md5
8
  from pathlib import Path
9
 
10
- import bitsandbytes as bnb
11
  import fire
12
  import torch
13
- import transformers
14
  import yaml
15
  from attrdict import AttrDefault
16
- from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
17
- from torch import nn
18
- from transformers import (
19
- AutoModelForCausalLM,
20
- AutoTokenizer,
21
- LlamaForCausalLM,
22
- LlamaTokenizer,
23
- EarlyStoppingCallback,
24
- GenerationConfig,
25
- )
26
 
27
  # add src to the pythonpath so we don't need to pip install this
28
- from transformers.trainer_pt_utils import get_parameter_names
29
-
30
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
31
  src_dir = os.path.join(project_root, "src")
32
  sys.path.insert(0, src_dir)
33
 
34
- from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
35
- from axolotl.prompt_tokenizers import (
36
- AlpacaPromptTokenizingStrategy,
37
- ShareGPTPromptTokenizingStrategy,
38
- LLAMA_DEFAULT_PAD_TOKEN,
39
- GPTeacherPromptTokenizingStrategy,
40
- OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy,
41
- )
42
- from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter
43
 
44
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
45
  DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
46
 
47
 
48
- def setup_wandb_env_vars(cfg):
49
- if cfg.wandb_project and len(cfg.wandb_project) > 0:
50
- os.environ["WANDB_PROJECT"] = cfg.wandb_project
51
- cfg.use_wandb = True
52
- if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
53
- os.environ["WANDB_WATCH"] = cfg.wandb_watch
54
- if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
55
- os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
56
- if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
57
- os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
58
-
59
-
60
- def load_model(
61
- base_model,
62
- base_model_config,
63
- model_type,
64
- tokenizer_type,
65
- cfg,
66
- adapter="lora",
67
- inference: bool = False,
68
- ):
69
- # TODO refactor as a kwarg
70
- load_in_8bit = cfg.load_in_8bit
71
- tokenizer = None
72
- is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
73
-
74
- if adapter != "lora":
75
- raise NotImplementedError(f"{adapter} peft adapter not available")
76
- if is_llama_derived_model and cfg.flash_attention:
77
- if cfg.device not in ["mps", "cpu"] and inference is False:
78
- from axolotl.flash_attn import replace_llama_attn_with_flash_attn
79
-
80
- logging.info("patching with flash attention")
81
- replace_llama_attn_with_flash_attn()
82
-
83
- torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
84
- try:
85
- if cfg.load_4bit:
86
- from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
87
- replace_peft_model_with_int4_lora_model,
88
- )
89
-
90
- replace_peft_model_with_int4_lora_model()
91
-
92
- from peft import (
93
- LoraConfig,
94
- get_peft_model,
95
- prepare_model_for_int8_training,
96
- PeftModel,
97
- )
98
- except Exception as e:
99
- logging.exception(e)
100
- raise e
101
-
102
- try:
103
- if cfg.load_4bit and is_llama_derived_model:
104
- from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
105
- from huggingface_hub import snapshot_download
106
-
107
- cache_model_path = Path(snapshot_download(base_model))
108
- files = (
109
- list(cache_model_path.glob("*.pt"))
110
- + list(cache_model_path.glob("*.safetensors"))
111
- + list(cache_model_path.glob("*.bin"))
112
- )
113
- if len(files) > 0:
114
- model_path = str(files[0])
115
- else:
116
- logging.warning(
117
- "unable to find a cached model file, this will likely fail..."
118
- )
119
- model_path = str(cache_model_path)
120
- model, tokenizer = load_llama_model_4bit_low_ram(
121
- base_model_config if base_model_config else base_model,
122
- model_path,
123
- device_map=cfg.device_map,
124
- groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
125
- is_v1_model=cfg.gptq_model_v1
126
- if cfg.gptq_model_v1 is not None
127
- else True,
128
- )
129
- load_in_8bit = False
130
- elif is_llama_derived_model:
131
- model = LlamaForCausalLM.from_pretrained(
132
- base_model,
133
- load_in_8bit=cfg.load_in_8bit,
134
- torch_dtype=torch_dtype,
135
- device_map=cfg.device_map,
136
- )
137
- else:
138
- model = getattr(transformers, model_type).from_pretrained(
139
- base_model,
140
- load_in_8bit=cfg.load_in_8bit,
141
- torch_dtype=torch_dtype,
142
- device_map=cfg.device_map,
143
- )
144
- except Exception as e:
145
- logging.error(
146
- "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
147
- )
148
- logging.exception(e)
149
- model = AutoModelForCausalLM.from_pretrained(
150
- base_model,
151
- load_in_8bit=cfg.load_in_8bit,
152
- torch_dtype=torch_dtype,
153
- device_map=cfg.device_map,
154
- )
155
-
156
- if not tokenizer:
157
- try:
158
- if is_llama_derived_model:
159
- tokenizer = LlamaTokenizer.from_pretrained(model)
160
- else:
161
- tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
162
- except:
163
- tokenizer = AutoTokenizer.from_pretrained(base_model)
164
-
165
- logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
166
- logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
167
- logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
168
- logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
169
-
170
- if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
171
- tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
172
-
173
- if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
174
- tokenizer.add_special_tokens({"pad_token": "[PAD]"})
175
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
176
-
177
- if load_in_8bit and not cfg.load_4bit:
178
- logging.info("converting model w/ prepare_model_for_int8_training")
179
- model = prepare_model_for_int8_training(model)
180
-
181
- lora_config = LoraConfig(
182
- r=cfg.lora_r,
183
- lora_alpha=cfg.lora_alpha,
184
- target_modules=cfg.lora_target_modules,
185
- lora_dropout=cfg.lora_dropout,
186
- fan_in_fan_out=cfg.lora_fan_in_fan_out,
187
- bias="none",
188
- task_type="CAUSAL_LM",
189
- )
190
-
191
- if cfg.lora_model_dir:
192
- model = PeftModel.from_pretrained(
193
- model,
194
- cfg.lora_model_dir,
195
- device_map=cfg.device_map,
196
- torch_dtype=torch.float16,
197
- )
198
- else:
199
- model = get_peft_model(model, lora_config)
200
-
201
- if cfg.ddp:
202
- model.to(f"cuda:{cfg.local_rank}")
203
-
204
- if cfg.load_4bit:
205
- # Scales to half
206
- logging.info("Fitting 4bit scales and zeros to half")
207
- for n, m in model.named_modules():
208
- if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
209
- type(m)
210
- ):
211
- if hasattr(m, "is_v1_model") and m.is_v1_model:
212
- m.zeros = m.zeros.half()
213
- m.scales = m.scales.half()
214
- m.bias = m.bias.half()
215
-
216
- # TODO resume_from_checkpoint handling
217
- model.print_trainable_parameters()
218
- return model, tokenizer, lora_config
219
-
220
-
221
  def choose_device(cfg):
222
  def get_device():
223
  if torch.cuda.is_available():
@@ -271,11 +77,13 @@ def do_inference(cfg, model, tokenizer):
271
  tokenizer.add_special_tokens({"bos_token": "<s>"})
272
  tokenizer.add_special_tokens({"eos_token": "</s>"})
273
 
274
- instruction = "Tell me a joke about dromedaries."
275
- input = ""
276
- prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
277
- instruction=instruction, input=input
 
278
  )
 
279
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
280
 
281
  model.eval()
@@ -324,98 +132,6 @@ def choose_config(path: Path):
324
  return chosen_file
325
 
326
 
327
- def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
328
- total_num_steps = int(
329
- math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
330
- )
331
- warmup_steps = min(int(0.03 * total_num_steps), 100)
332
- logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
333
- save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
334
-
335
- training_arguments_kwargs = {}
336
- if cfg.bf16 == "full":
337
- training_arguments_kwargs["bf16_full_eval"] = True
338
- else:
339
- training_arguments_kwargs["bf16"] = cfg.bf16
340
- training_arguments_kwargs["tf32"] = cfg.tf32
341
- training_arguments_kwargs["warmup_steps"] = warmup_steps
342
- training_arguments_kwargs["logging_steps"] = logging_steps
343
- if cfg.gradient_checkpointing is not None:
344
- training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
345
-
346
- training_args = transformers.TrainingArguments(
347
- per_device_train_batch_size=cfg.micro_batch_size,
348
- gradient_accumulation_steps=cfg.gradient_accumulation_steps,
349
- num_train_epochs=cfg.num_epochs,
350
- learning_rate=cfg.learning_rate,
351
- evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
352
- save_strategy="steps",
353
- eval_steps=eval_steps if cfg.val_set_size > 0 else None,
354
- save_steps=save_steps,
355
- output_dir=cfg.output_dir,
356
- save_total_limit=3,
357
- load_best_model_at_end=True if cfg.val_set_size > 0 else False,
358
- ddp_find_unused_parameters=False if cfg.ddp else None,
359
- group_by_length=cfg.group_by_length,
360
- report_to="wandb" if cfg.use_wandb else None,
361
- run_name=cfg.wandb_run_id if cfg.use_wandb else None,
362
- **training_arguments_kwargs,
363
- )
364
-
365
- decay_parameters = get_parameter_names(model, [nn.LayerNorm])
366
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
367
- optimizer_grouped_parameters = [
368
- {
369
- "params": [p for n, p in model.named_parameters() if n in decay_parameters],
370
- "weight_decay": training_args.weight_decay,
371
- },
372
- {
373
- "params": [
374
- p for n, p in model.named_parameters() if n not in decay_parameters
375
- ],
376
- "weight_decay": 0.0,
377
- },
378
- ]
379
-
380
- trainer_kwargs = {}
381
-
382
- if cfg.load_in_8bit and not cfg.load_4bit:
383
- adam_bnb_optim = bnb.optim.Adam8bit(
384
- optimizer_grouped_parameters,
385
- betas=(training_args.adam_beta1, training_args.adam_beta2),
386
- eps=training_args.adam_epsilon,
387
- lr=training_args.learning_rate,
388
- )
389
-
390
- # TODO optionally use torch.optim.OneCycleLR
391
- lr_scheduler = transformers.get_cosine_schedule_with_warmup(
392
- adam_bnb_optim,
393
- training_args.warmup_steps,
394
- total_num_steps,
395
- )
396
- trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
397
-
398
- # TODO on_save callback to sync checkpoints to GCP/AWS in background
399
- if cfg.early_stopping_patience:
400
- early_stop_cb = EarlyStoppingCallback(
401
- cfg.early_stopping_patience,
402
- )
403
- trainer_kwargs["callbacks"] = [early_stop_cb]
404
-
405
- trainer = transformers.Trainer(
406
- model=model,
407
- train_dataset=train_dataset,
408
- eval_dataset=eval_dataset,
409
- args=training_args,
410
- data_collator=transformers.DataCollatorForSeq2Seq(
411
- tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
412
- ),
413
- **trainer_kwargs,
414
- )
415
-
416
- return trainer
417
-
418
-
419
  def train(
420
  config: Path = Path("configs/"),
421
  prepare_ds_only: bool = False,
@@ -474,110 +190,13 @@ def train(
474
  do_inference(cfg, model, tokenizer)
475
  return
476
 
477
- max_packed_sequence_len = (
478
- cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
479
- )
480
- max_packed_sequence_len = min(
481
- max_packed_sequence_len, cfg.sequence_len
482
- ) # make sure we don't accidentally set it larger than sequence_len
483
- ds_hash = str(
484
- md5(
485
- (
486
- str(max_packed_sequence_len)
487
- + "@"
488
- + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
489
- ).encode("utf-8")
490
- ).hexdigest()
491
- )
492
- prepared_ds_path = (
493
- Path(cfg.dataset_prepared_path) / ds_hash
494
- if cfg.dataset_prepared_path
495
- else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
496
  )
497
 
498
- if any(prepared_ds_path.glob("*")):
499
- logging.info("Loading prepared dataset from disk...")
500
- dataset = load_from_disk(str(prepared_ds_path))
501
- logging.info("Prepared dataset loaded from disk...")
502
- else:
503
- logging.info("Loading raw datasets...")
504
- datasets = []
505
- for d in cfg.datasets:
506
- ds_from_hub = False
507
- try:
508
- load_dataset(d.path, streaming=True)
509
- ds_from_hub = True
510
- except FileNotFoundError:
511
- pass
512
-
513
- # prefer local dataset, even if hub exists
514
- if Path(d.path).exists():
515
- ds: IterableDataset = load_dataset(
516
- "json", data_files=d.path, streaming=True, split=None
517
- )
518
- elif ds_from_hub:
519
- ds = load_dataset(d.path, streaming=True)
520
- else:
521
- raise Exception("unhandled dataset load")
522
-
523
- if d.type == "alpaca":
524
- ds_strategy = AlpacaPromptTokenizingStrategy(
525
- AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
526
- )
527
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
528
- datasets.append(ds_wrapper)
529
- elif d.type == "oasst":
530
- ds_strategy = OpenAssistantPromptTokenizingStrategy(
531
- AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
532
- )
533
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
534
- datasets.append(ds_wrapper)
535
- elif d.type == "gpteacher":
536
- ds_strategy = GPTeacherPromptTokenizingStrategy(
537
- GPTeacherPrompter(),
538
- tokenizer,
539
- cfg.train_on_inputs,
540
- cfg.sequence_len,
541
- )
542
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
543
- datasets.append(ds_wrapper)
544
- elif d.type == "reflection":
545
- ds_strategy = AlpacaReflectionPTStrategy(
546
- ReflectAlpacaPrompter(),
547
- tokenizer,
548
- cfg.train_on_inputs,
549
- cfg.sequence_len,
550
- )
551
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
552
- datasets.append(ds_wrapper)
553
- elif d.type == "sharegpt":
554
- ds_strategy = ShareGPTPromptTokenizingStrategy(
555
- ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
556
- )
557
- ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
558
- datasets.append(ds_wrapper)
559
- else:
560
- logging.error(f"unhandled prompt tokenization strategy: {d.type}")
561
- constant_len_dataset = ConstantLengthDataset(
562
- tokenizer,
563
- datasets,
564
- seq_length=max_packed_sequence_len,
565
- )
566
- logging.info("merging, packing, shuffling, and splitting master dataset")
567
- dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
568
- test_size=cfg.val_set_size, shuffle=True, seed=42
569
- )
570
-
571
- if cfg.local_rank == 0:
572
- logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
573
- dataset.save_to_disk(prepared_ds_path)
574
-
575
- if prepare_ds_only:
576
- logging.info("Finished preparing dataset. Exiting...")
577
- return
578
-
579
- train_dataset = dataset["train"]
580
- eval_dataset = dataset["test"]
581
 
582
  if cfg.debug:
583
  check_dataset_labels(
@@ -594,8 +213,9 @@ def train(
594
  model = torch.compile(model)
595
 
596
  # go ahead and presave, so we have the adapter config available to inspect
597
- logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
598
- lora_config.save_pretrained(cfg.output_dir)
 
599
 
600
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
601
  if cfg.local_rank == 0:
 
1
  import logging
 
2
  import os
3
  import random
4
  import signal
5
  import sys
 
6
  from pathlib import Path
7
 
 
8
  import fire
9
  import torch
 
10
  import yaml
11
  from attrdict import AttrDefault
 
 
 
 
 
 
 
 
 
 
12
 
13
  # add src to the pythonpath so we don't need to pip install this
 
 
14
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
15
  src_dir = os.path.join(project_root, "src")
16
  sys.path.insert(0, src_dir)
17
 
18
+ from axolotl.utils.data import load_prepare_datasets
19
+ from axolotl.utils.models import load_model
20
+ from axolotl.utils.trainer import setup_trainer
21
+ from axolotl.utils.wandb import setup_wandb_env_vars
 
 
 
 
 
22
 
23
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
24
  DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def choose_device(cfg):
28
  def get_device():
29
  if torch.cuda.is_available():
 
77
  tokenizer.add_special_tokens({"bos_token": "<s>"})
78
  tokenizer.add_special_tokens({"eos_token": "</s>"})
79
 
80
+ from axolotl.prompters import ReflectAlpacaPrompter
81
+
82
+ instruction = str(input("Give me an instruction: "))
83
+ instruction = (
84
+ instruction if not instruction else "Tell me a joke about dromedaries."
85
  )
86
+ prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
87
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
88
 
89
  model.eval()
 
132
  return chosen_file
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def train(
136
  config: Path = Path("configs/"),
137
  prepare_ds_only: bool = False,
 
190
  do_inference(cfg, model, tokenizer)
191
  return
192
 
193
+ train_dataset, eval_dataset = load_prepare_datasets(
194
+ tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  )
196
 
197
+ if prepare_ds_only:
198
+ logging.info("Finished preparing dataset. Exiting...")
199
+ return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  if cfg.debug:
202
  check_dataset_labels(
 
213
  model = torch.compile(model)
214
 
215
  # go ahead and presave, so we have the adapter config available to inspect
216
+ if lora_config:
217
+ logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
218
+ lora_config.save_pretrained(cfg.output_dir)
219
 
220
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
221
  if cfg.local_rank == 0:
src/axolotl/prompt_tokenizers.py CHANGED
@@ -107,6 +107,15 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
107
  )
108
 
109
 
 
 
 
 
 
 
 
 
 
110
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
111
  def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
112
  raise NotImplementedError
@@ -168,6 +177,7 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
168
  prompt["corrected"],
169
  )
170
 
 
171
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
172
  def tokenize_prompt(self, prompt):
173
  try:
 
107
  )
108
 
109
 
110
+ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
111
+ def parse_instruction_fields(self, prompt) -> (str, str, str):
112
+ return (
113
+ prompt["prompt"],
114
+ "",
115
+ prompt["response"],
116
+ )
117
+
118
+
119
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
120
  def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
121
  raise NotImplementedError
 
177
  prompt["corrected"],
178
  )
179
 
180
+
181
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
182
  def tokenize_prompt(self, prompt):
183
  try:
src/axolotl/prompters.py CHANGED
@@ -35,6 +35,10 @@ class GPTeacherPrompter(AlpacaPrompter):
35
  ...
36
 
37
 
 
 
 
 
38
  class ReflectAlpacaPrompter:
39
  prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
40
  prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
 
35
  ...
36
 
37
 
38
+ class NomicGPT4AllPrompter(AlpacaPrompter):
39
+ ...
40
+
41
+
42
  class ReflectAlpacaPrompter:
43
  prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
44
  prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
src/axolotl/utils/__init__.py ADDED
File without changes
src/axolotl/utils/data.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from hashlib import md5
3
+ from pathlib import Path
4
+
5
+ from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
6
+
7
+ from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
8
+ from axolotl.prompt_tokenizers import (
9
+ AlpacaPromptTokenizingStrategy,
10
+ GPTeacherPromptTokenizingStrategy,
11
+ OpenAssistantPromptTokenizingStrategy,
12
+ AlpacaReflectionPTStrategy,
13
+ ShareGPTPromptTokenizingStrategy,
14
+ )
15
+ from axolotl.prompters import (
16
+ AlpacaPrompter,
17
+ GPTeacherPrompter,
18
+ ReflectAlpacaPrompter,
19
+ ShareGPTPrompter,
20
+ )
21
+
22
+
23
+ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
24
+ max_packed_sequence_len = (
25
+ cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
26
+ )
27
+ max_packed_sequence_len = min(
28
+ max_packed_sequence_len, cfg.sequence_len
29
+ ) # make sure we don't accidentally set it larger than sequence_len
30
+ ds_hash = str(
31
+ md5(
32
+ (
33
+ str(max_packed_sequence_len)
34
+ + "@"
35
+ + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
36
+ ).encode("utf-8")
37
+ ).hexdigest()
38
+ )
39
+ prepared_ds_path = (
40
+ Path(cfg.dataset_prepared_path) / ds_hash
41
+ if cfg.dataset_prepared_path
42
+ else Path(default_dataset_prepared_path) / ds_hash
43
+ )
44
+
45
+ if any(prepared_ds_path.glob("*")):
46
+ logging.info("Loading prepared dataset from disk...")
47
+ dataset = load_from_disk(str(prepared_ds_path))
48
+ logging.info("Prepared dataset loaded from disk...")
49
+ else:
50
+ logging.info("Loading raw datasets...")
51
+ datasets = []
52
+ for d in cfg.datasets:
53
+ ds_from_hub = False
54
+ try:
55
+ load_dataset(d.path, streaming=True)
56
+ ds_from_hub = True
57
+ except FileNotFoundError:
58
+ pass
59
+
60
+ # prefer local dataset, even if hub exists
61
+ if Path(d.path).exists():
62
+ ds: IterableDataset = load_dataset(
63
+ "json", data_files=d.path, streaming=True, split=None
64
+ )
65
+ elif ds_from_hub:
66
+ ds = load_dataset(d.path, streaming=True)
67
+ else:
68
+ raise Exception("unhandled dataset load")
69
+
70
+ if d.type == "alpaca":
71
+ ds_strategy = AlpacaPromptTokenizingStrategy(
72
+ AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
73
+ )
74
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
75
+ datasets.append(ds_wrapper)
76
+ elif d.type == "oasst":
77
+ ds_strategy = OpenAssistantPromptTokenizingStrategy(
78
+ AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
79
+ )
80
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
81
+ datasets.append(ds_wrapper)
82
+ elif d.type == "gpteacher":
83
+ ds_strategy = GPTeacherPromptTokenizingStrategy(
84
+ GPTeacherPrompter(),
85
+ tokenizer,
86
+ cfg.train_on_inputs,
87
+ cfg.sequence_len,
88
+ )
89
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
90
+ datasets.append(ds_wrapper)
91
+ elif d.type == "reflection":
92
+ ds_strategy = AlpacaReflectionPTStrategy(
93
+ ReflectAlpacaPrompter(),
94
+ tokenizer,
95
+ cfg.train_on_inputs,
96
+ cfg.sequence_len,
97
+ )
98
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
99
+ datasets.append(ds_wrapper)
100
+ elif d.type == "sharegpt":
101
+ ds_strategy = ShareGPTPromptTokenizingStrategy(
102
+ ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
103
+ )
104
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
105
+ datasets.append(ds_wrapper)
106
+ else:
107
+ logging.error(f"unhandled prompt tokenization strategy: {d.type}")
108
+ constant_len_dataset = ConstantLengthDataset(
109
+ tokenizer,
110
+ datasets,
111
+ seq_length=max_packed_sequence_len,
112
+ )
113
+ logging.info("merging, packing, shuffling, and splitting master dataset")
114
+ dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
115
+ test_size=cfg.val_set_size, shuffle=True, seed=42
116
+ )
117
+
118
+ if cfg.local_rank == 0:
119
+ logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
120
+ dataset.save_to_disk(prepared_ds_path)
121
+
122
+ train_dataset = dataset["train"]
123
+ eval_dataset = dataset["test"]
124
+
125
+ return train_dataset, eval_dataset
src/axolotl/utils/models.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple, TYPE_CHECKING
5
+
6
+ import torch
7
+ import transformers
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ LlamaForCausalLM,
11
+ LlamaTokenizer,
12
+ AutoTokenizer,
13
+ PreTrainedModel,
14
+ )
15
+
16
+ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
17
+
18
+ if TYPE_CHECKING:
19
+ from peft import PeftModel, PeftConfig
20
+ from attrdict import AttrDefault
21
+ from transformers import PreTrainedTokenizer
22
+
23
+
24
+ def load_model(
25
+ base_model,
26
+ base_model_config,
27
+ model_type,
28
+ tokenizer_type,
29
+ cfg,
30
+ adapter="lora",
31
+ inference=False,
32
+ ):
33
+ # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
34
+
35
+ # TODO refactor as a kwarg
36
+ load_in_8bit = cfg.load_in_8bit
37
+ tokenizer = None
38
+ is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
39
+
40
+ if is_llama_derived_model and cfg.flash_attention:
41
+ if cfg.device not in ["mps", "cpu"] and inference is False:
42
+ from axolotl.flash_attn import replace_llama_attn_with_flash_attn
43
+
44
+ logging.info("patching with flash attention")
45
+ replace_llama_attn_with_flash_attn()
46
+
47
+ torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
48
+ try:
49
+ if cfg.load_4bit:
50
+ from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
51
+ replace_peft_model_with_int4_lora_model,
52
+ )
53
+
54
+ replace_peft_model_with_int4_lora_model()
55
+ from peft import prepare_model_for_int8_training
56
+ except Exception as e:
57
+ logging.exception(e)
58
+ raise e
59
+
60
+ try:
61
+ if cfg.load_4bit and is_llama_derived_model:
62
+ from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
63
+ from huggingface_hub import snapshot_download
64
+
65
+ cache_model_path = Path(snapshot_download(base_model))
66
+ files = (
67
+ list(cache_model_path.glob("*.pt"))
68
+ + list(cache_model_path.glob("*.safetensors"))
69
+ + list(cache_model_path.glob("*.bin"))
70
+ )
71
+ if len(files) > 0:
72
+ model_path = str(files[0])
73
+ else:
74
+ logging.warning(
75
+ "unable to find a cached model file, this will likely fail..."
76
+ )
77
+ model_path = str(cache_model_path)
78
+ model, tokenizer = load_llama_model_4bit_low_ram(
79
+ base_model_config if base_model_config else base_model,
80
+ model_path,
81
+ device_map=cfg.device_map,
82
+ groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
83
+ is_v1_model=cfg.gptq_model_v1
84
+ if cfg.gptq_model_v1 is not None
85
+ else True,
86
+ )
87
+ load_in_8bit = False
88
+ elif is_llama_derived_model:
89
+ model = LlamaForCausalLM.from_pretrained(
90
+ base_model,
91
+ load_in_8bit=cfg.load_in_8bit,
92
+ torch_dtype=torch_dtype,
93
+ device_map=cfg.device_map,
94
+ )
95
+ else:
96
+ model = getattr(transformers, model_type).from_pretrained(
97
+ base_model,
98
+ load_in_8bit=cfg.load_in_8bit,
99
+ torch_dtype=torch_dtype,
100
+ device_map=cfg.device_map,
101
+ )
102
+ except Exception as e:
103
+ logging.error(
104
+ "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
105
+ )
106
+ logging.exception(e)
107
+ model = AutoModelForCausalLM.from_pretrained(
108
+ base_model,
109
+ load_in_8bit=cfg.load_in_8bit,
110
+ torch_dtype=torch_dtype,
111
+ device_map=cfg.device_map,
112
+ )
113
+
114
+ if not tokenizer:
115
+ try:
116
+ if is_llama_derived_model:
117
+ tokenizer = LlamaTokenizer.from_pretrained(model)
118
+ else:
119
+ tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
120
+ except:
121
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
122
+
123
+ logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
124
+ logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
125
+ logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
126
+ logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
127
+
128
+ if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
129
+ tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
130
+
131
+ if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
132
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
133
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
134
+
135
+ if load_in_8bit and not cfg.load_4bit:
136
+ logging.info("converting model w/ prepare_model_for_int8_training")
137
+ model = prepare_model_for_int8_training(model)
138
+
139
+ model, lora_config = load_adapter(model, cfg, adapter)
140
+
141
+ if cfg.ddp:
142
+ model.to(f"cuda:{cfg.local_rank}")
143
+
144
+ if cfg.load_4bit:
145
+ # Scales to half
146
+ logging.info("Fitting 4bit scales and zeros to half")
147
+ for n, m in model.named_modules():
148
+ if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
149
+ type(m)
150
+ ):
151
+ if hasattr(m, "is_v1_model") and m.is_v1_model:
152
+ m.zeros = m.zeros.half()
153
+ m.scales = m.scales.half()
154
+ m.bias = m.bias.half()
155
+
156
+ # TODO resume_from_checkpoint handling
157
+ return model, tokenizer, lora_config
158
+
159
+
160
+ def load_adapter(model, cfg, adapter):
161
+ # type: (PreTrainedModel, AttrDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
162
+
163
+ if adapter is None:
164
+ return model, None
165
+ if adapter == "lora":
166
+ return load_lora(model, cfg)
167
+ # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
168
+
169
+ raise NotImplementedError(f"{adapter} peft adapter not available")
170
+
171
+
172
+ def load_lora(model, cfg):
173
+ # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
174
+
175
+ from peft import (
176
+ LoraConfig,
177
+ get_peft_model,
178
+ PeftModel,
179
+ )
180
+
181
+ lora_config = None
182
+
183
+ if cfg.adapter == "lora":
184
+ lora_config = LoraConfig(
185
+ r=cfg.lora_r,
186
+ lora_alpha=cfg.lora_alpha,
187
+ target_modules=cfg.lora_target_modules,
188
+ lora_dropout=cfg.lora_dropout,
189
+ fan_in_fan_out=cfg.lora_fan_in_fan_out,
190
+ bias="none",
191
+ task_type="CAUSAL_LM",
192
+ )
193
+
194
+ if cfg.lora_model_dir:
195
+ model = PeftModel.from_pretrained(
196
+ model,
197
+ cfg.lora_model_dir,
198
+ device_map=cfg.device_map,
199
+ torch_dtype=torch.float16,
200
+ )
201
+ else:
202
+ model = get_peft_model(model, lora_config)
203
+
204
+ model.print_trainable_parameters()
205
+
206
+ return model, lora_config
src/axolotl/utils/trainer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import bitsandbytes as bnb
3
+ import transformers
4
+ from torch import nn
5
+ from torch.optim.lr_scheduler import OneCycleLR
6
+ from transformers import EarlyStoppingCallback
7
+ from transformers.trainer_pt_utils import get_parameter_names
8
+
9
+
10
+ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
11
+ total_num_steps = int(
12
+ math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
13
+ )
14
+ warmup_steps = min(int(0.03 * total_num_steps), 100)
15
+ logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
16
+ save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
17
+
18
+ training_arguments_kwargs = {}
19
+ if cfg.bf16 == "full":
20
+ training_arguments_kwargs["bf16_full_eval"] = True
21
+ else:
22
+ training_arguments_kwargs["bf16"] = cfg.bf16
23
+ training_arguments_kwargs["tf32"] = cfg.tf32
24
+ training_arguments_kwargs["warmup_steps"] = warmup_steps
25
+ training_arguments_kwargs["logging_steps"] = logging_steps
26
+ if cfg.gradient_checkpointing is not None:
27
+ training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
28
+
29
+ training_args = transformers.TrainingArguments(
30
+ per_device_train_batch_size=cfg.micro_batch_size,
31
+ gradient_accumulation_steps=cfg.gradient_accumulation_steps,
32
+ num_train_epochs=cfg.num_epochs,
33
+ learning_rate=cfg.learning_rate,
34
+ evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
35
+ save_strategy="steps",
36
+ eval_steps=eval_steps if cfg.val_set_size > 0 else None,
37
+ save_steps=save_steps,
38
+ output_dir=cfg.output_dir,
39
+ save_total_limit=3,
40
+ load_best_model_at_end=True if cfg.val_set_size > 0 else False,
41
+ ddp_find_unused_parameters=False if cfg.ddp else None,
42
+ group_by_length=cfg.group_by_length,
43
+ report_to="wandb" if cfg.use_wandb else None,
44
+ run_name=cfg.wandb_run_id if cfg.use_wandb else None,
45
+ **training_arguments_kwargs,
46
+ )
47
+
48
+ decay_parameters = get_parameter_names(model, [nn.LayerNorm])
49
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
50
+ optimizer_grouped_parameters = [
51
+ {
52
+ "params": [p for n, p in model.named_parameters() if n in decay_parameters],
53
+ "weight_decay": training_args.weight_decay,
54
+ },
55
+ {
56
+ "params": [
57
+ p for n, p in model.named_parameters() if n not in decay_parameters
58
+ ],
59
+ "weight_decay": 0.0,
60
+ },
61
+ ]
62
+
63
+ trainer_kwargs = {}
64
+
65
+ if cfg.load_in_8bit and not cfg.load_4bit:
66
+ optimizer = bnb.optim.Adam8bit(
67
+ optimizer_grouped_parameters,
68
+ betas=(training_args.adam_beta1, training_args.adam_beta2),
69
+ eps=training_args.adam_epsilon,
70
+ lr=training_args.learning_rate,
71
+ )
72
+
73
+ if cfg.lr_scheduler == "one_cycle":
74
+ lr_scheduler_kwargs = (
75
+ cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
76
+ )
77
+ lr_scheduler = OneCycleLR(
78
+ optimizer,
79
+ cfg.learning_rate,
80
+ total_steps=total_num_steps,
81
+ **lr_scheduler_kwargs,
82
+ )
83
+ else:
84
+ lr_scheduler = transformers.get_cosine_schedule_with_warmup(
85
+ optimizer,
86
+ training_args.warmup_steps,
87
+ total_num_steps,
88
+ )
89
+ trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
90
+
91
+ # TODO on_save callback to sync checkpoints to GCP/AWS in background
92
+ if cfg.early_stopping_patience:
93
+ early_stop_cb = EarlyStoppingCallback(
94
+ cfg.early_stopping_patience,
95
+ )
96
+ trainer_kwargs["callbacks"] = [early_stop_cb]
97
+
98
+ trainer = transformers.Trainer(
99
+ model=model,
100
+ train_dataset=train_dataset,
101
+ eval_dataset=eval_dataset,
102
+ args=training_args,
103
+ data_collator=transformers.DataCollatorForSeq2Seq(
104
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
105
+ ),
106
+ **trainer_kwargs,
107
+ )
108
+
109
+ return trainer
src/axolotl/utils/wandb.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ def setup_wandb_env_vars(cfg):
5
+ if cfg.wandb_project and len(cfg.wandb_project) > 0:
6
+ os.environ["WANDB_PROJECT"] = cfg.wandb_project
7
+ cfg.use_wandb = True
8
+ if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
9
+ os.environ["WANDB_WATCH"] = cfg.wandb_watch
10
+ if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
11
+ os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
12
+ if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
13
+ os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id