winglian committed
Commit 12de7b7
1 parent: d1aed4c

cleanup, prep for 4bit quant support

Files changed (3)
  1. README.md +21 -1
  2. scripts/finetune.py +18 -6
  3. setup.cfg +3 -0
README.md CHANGED
@@ -30,4 +30,24 @@ shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
 
 - Create a new or update the existing YAML config (config/pythia_1_2B_alpaca.yml)[config/pythia_1_2B_alpaca.yml]
 - Install python dependencies `pip3 install -r requirements.txt`
-- Train! `python3 scripts/finetune.py`, make sure to choose the correct YAML config file
+- Configure accelerate `accelerate launch` or update `~/.cache/huggingface/accelerate/default_config.yaml`
+
+```yaml
+compute_environment: LOCAL_MACHINE
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+- Train! `accelerate launch scripts/finetune.py`, make sure to choose the correct YAML config file
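As a note on the new README step: instead of running `accelerate config` interactively, the same settings can be written straight to the default location. Below is a minimal sketch, assuming PyYAML is installed and the standard `~/.cache/huggingface/accelerate/default_config.yaml` path shown above; set `num_processes` to your GPU count.

```python
# Sketch: write the accelerate config shown above to its default location.
# Assumes PyYAML is available; adjust num_processes to the number of GPUs.
from pathlib import Path

import yaml

config = {
    "compute_environment": "LOCAL_MACHINE",
    "distributed_type": "MULTI_GPU",
    "downcast_bf16": "no",
    "gpu_ids": "all",
    "machine_rank": 0,
    "main_training_function": "main",
    "mixed_precision": "bf16",
    "num_machines": 1,
    "num_processes": 4,
    "rdzv_backend": "static",
    "same_network": True,
    "tpu_env": [],
    "tpu_use_cluster": False,
    "tpu_use_sudo": False,
    "use_cpu": False,
}

path = Path.home() / ".cache" / "huggingface" / "accelerate" / "default_config.yaml"
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(yaml.safe_dump(config))
print(f"wrote {path}")
```

With that file in place, `accelerate launch scripts/finetune.py` picks the default config up automatically.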
scripts/finetune.py CHANGED
@@ -68,26 +68,27 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora", infe
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn
         replace_llama_attn_with_flash_attn()
 
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
     try:
         if "llama" in base_model:
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
         else:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit,
-                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+            torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
 
@@ -235,7 +236,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
-    training_arguments_kwargs["bf16"] = cfg.bf16
+    if cfg.bf16 == "full":
+        training_arguments_kwargs["bf16_full_eval"] = True
+    else:
+        training_arguments_kwargs["bf16"] = cfg.bf16
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
@@ -256,10 +260,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        gradient_checkpointing=cfg.gradient_checkpointing,
         **training_arguments_kwargs,
     )
 
-    trainer_kwargs = {}
     decay_parameters = get_parameter_names(model, [nn.LayerNorm])
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
     optimizer_grouped_parameters = [
@@ -282,13 +286,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         lr=training_args.learning_rate,
     )
 
+    # TODO optionally use torch.optim.OneCycleLR
     lr_scheduler = transformers.get_cosine_schedule_with_warmup(
         adam_bnb_optim,
         training_args.warmup_steps,
         total_num_steps,
     )
-    trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)
 
+    trainer_kwargs = {}
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
             cfg.early_stopping_patience,
@@ -300,6 +305,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
+        optimizers=(adam_bnb_optim, lr_scheduler),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
         ),
@@ -342,6 +348,12 @@ def train(
         cfg.gradient_accumulation_steps // cfg.world_size
     )
     setup_wandb_env_vars(cfg)
+    if cfg.device == "mps":
+        cfg.load_in_8bit = False
+        cfg.tf32 = False
+        if cfg.bf16:
+            cfg.fp16 = True
+        cfg.bf16 = False
 
     # Load the model and tokenizer
     model, tokenizer, lora_config = load_model(
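Taken together, the `scripts/finetune.py` changes route precision through the config: `torch_dtype` now also honors `cfg.fp16`, `bf16: full` enables bf16 evaluation via `bf16_full_eval`, gradient checkpointing becomes configurable, the optimizer/scheduler pair is passed directly to the `Trainer` via `optimizers=`, and an `mps` device disables 8-bit loading and tf32 and swaps bf16 for fp16. Below is a minimal standalone sketch of just that flag handling, with `cfg` mocked as a simple namespace; note that the trailing comma on the `torch_dtype` assignment in the diff would actually produce a one-element tuple, which the sketch drops.

```python
# Sketch of the precision/dtype plumbing from this commit, isolated from the
# rest of finetune.py. `cfg` stands in for the parsed YAML config.
from types import SimpleNamespace

import torch

cfg = SimpleNamespace(
    device="cuda", load_in_8bit=True, fp16=False, bf16="full", tf32=True
)

# mps (Apple Silicon) path: drop 8-bit loading and tf32, and fall back
# from bf16 to fp16
if cfg.device == "mps":
    cfg.load_in_8bit = False
    cfg.tf32 = False
    if cfg.bf16:
        cfg.fp16 = True
    cfg.bf16 = False

# load weights in fp16 when quantizing to 8-bit or training in fp16,
# otherwise keep fp32
torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32

# "full" also runs evaluation in bf16; otherwise pass the flag through
training_arguments_kwargs = {}
if cfg.bf16 == "full":
    training_arguments_kwargs["bf16_full_eval"] = True
else:
    training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["tf32"] = cfg.tf32

print(torch_dtype, training_arguments_kwargs)
```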
setup.cfg CHANGED
@@ -28,3 +28,6 @@ install_requires =
 [options.packages.find]
 where = src
 
+[options.extras_require]
+gptq_cuda = alpaca_lora_4bit[cuda] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[cuda]
+gptq_triton = alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip#egg=alpaca_lora_4bit[triton]
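These extras keep the GPTQ 4-bit dependency optional: `pip3 install -e '.[gptq_cuda]'` or `pip3 install -e '.[gptq_triton]'` pulls in the matching backend without burdening the base install. A small sketch of probing for it at runtime, under the assumption (not confirmed by this diff) that the extra installs an importable top-level `alpaca_lora_4bit` module:

```python
# Sketch: detect whether the optional 4-bit GPTQ backend is installed.
# Assumption: the extra provides a top-level `alpaca_lora_4bit` module.
import importlib.util


def gptq_available() -> bool:
    """True if the optional alpaca_lora_4bit dependency can be imported."""
    return importlib.util.find_spec("alpaca_lora_4bit") is not None


if __name__ == "__main__":
    if gptq_available():
        print("4-bit GPTQ backend found")
    else:
        print("Install it with: pip3 install -e '.[gptq_cuda]'  (or '.[gptq_triton]')")
```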