zetavg committed on
Commit
2aa964c
·
1 Parent(s): 34968a1
Files changed (1) hide show
  1. llama_lora/lib/finetune.py +45 -36
llama_lora/lib/finetune.py CHANGED
@@ -92,7 +92,7 @@ def train(
92
  try:
93
  additional_lora_config = json.loads(additional_lora_config)
94
  except Exception as e:
95
- raise ValueError(f"Could not parse additional_training_arguments: {e}")
96
 
97
  # for logging
98
  finetune_args = {
@@ -262,16 +262,19 @@ def train(
262
 
263
  # model = prepare_model_for_int8_training(model)
264
 
265
- config = LoraConfig(
266
- r=lora_r,
267
- lora_alpha=lora_alpha,
268
- target_modules=lora_target_modules,
269
- modules_to_save=lora_modules_to_save,
270
- lora_dropout=lora_dropout,
271
- bias="none",
272
- task_type="CAUSAL_LM",
273
- **additional_lora_config,
274
- )
 
 
 
275
  model = get_peft_model(model, config)
276
  if bf16:
277
  model = model.to(torch.bfloat16)
@@ -336,36 +339,42 @@ def train(
336
  model.is_parallelizable = True
337
  model.model_parallel = True
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer
340
  trainer = transformers.Trainer(
341
  model=model,
342
  train_dataset=train_data,
343
  eval_dataset=val_data,
344
- # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
345
- args=transformers.TrainingArguments(
346
- per_device_train_batch_size=micro_batch_size,
347
- gradient_checkpointing=gradient_checkpointing,
348
- gradient_accumulation_steps=gradient_accumulation_steps,
349
- warmup_steps=100,
350
- num_train_epochs=num_train_epochs,
351
- learning_rate=learning_rate,
352
- fp16=fp16,
353
- bf16=bf16,
354
- logging_steps=logging_steps,
355
- optim="adamw_torch",
356
- evaluation_strategy="steps" if val_set_size > 0 else "no",
357
- save_strategy="steps",
358
- eval_steps=save_steps if val_set_size > 0 else None,
359
- save_steps=save_steps,
360
- output_dir=output_dir,
361
- save_total_limit=save_total_limit,
362
- load_best_model_at_end=True if val_set_size > 0 else False,
363
- ddp_find_unused_parameters=False if ddp else None,
364
- group_by_length=group_by_length,
365
- report_to="wandb" if use_wandb else None,
366
- run_name=wandb_run_name if use_wandb else None,
367
- **additional_training_arguments
368
- ),
369
  data_collator=transformers.DataCollatorForSeq2Seq(
370
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
371
  ),
 
92
  try:
93
  additional_lora_config = json.loads(additional_lora_config)
94
  except Exception as e:
95
+ raise ValueError(f"Could not parse additional_lora_config: {e}")
96
 
97
  # for logging
98
  finetune_args = {
 
262
 
263
  # model = prepare_model_for_int8_training(model)
264
 
265
+ lora_config_args = {
266
+ 'r': lora_r,
267
+ 'lora_alpha': lora_alpha,
268
+ 'target_modules': lora_target_modules,
269
+ 'modules_to_save': lora_modules_to_save,
270
+ 'lora_dropout': lora_dropout,
271
+ 'bias': "none",
272
+ 'task_type': "CAUSAL_LM",
273
+ }
274
+ config = LoraConfig(**{
275
+ **lora_config_args,
276
+ **(additional_lora_config or {}),
277
+ })
278
  model = get_peft_model(model, config)
279
  if bf16:
280
  model = model.to(torch.bfloat16)
 
339
  model.is_parallelizable = True
340
  model.model_parallel = True
341
 
342
+ # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
343
+ training_args = {
344
+ 'output_dir': output_dir,
345
+ 'per_device_train_batch_size': micro_batch_size,
346
+ 'gradient_checkpointing': gradient_checkpointing,
347
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
348
+ 'warmup_steps': 100,
349
+ 'num_train_epochs': num_train_epochs,
350
+ 'learning_rate': learning_rate,
351
+ 'fp16': fp16,
352
+ 'bf16': bf16,
353
+ 'logging_steps': logging_steps,
354
+ 'optim': "adamw_torch",
355
+ 'evaluation_strategy': "steps" if val_set_size > 0 else "no",
356
+ 'save_strategy': "steps",
357
+ 'eval_steps': save_steps if val_set_size > 0 else None,
358
+ 'save_steps': save_steps,
359
+ 'output_dir': output_dir,
360
+ 'save_total_limit': save_total_limit,
361
+ 'load_best_model_at_end': True if val_set_size > 0 else False,
362
+ 'ddp_find_unused_parameters': False if ddp else None,
363
+ 'group_by_length': group_by_length,
364
+ 'report_to': "wandb" if use_wandb else None,
365
+ 'run_name': wandb_run_name if use_wandb else None,
366
+ }
367
+
368
  # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer
369
  trainer = transformers.Trainer(
370
  model=model,
371
  train_dataset=train_data,
372
  eval_dataset=val_data,
373
+ tokenizer=tokenizer,
374
+ args=transformers.TrainingArguments(**{
375
+ **training_args,
376
+ **(additional_training_arguments or {})
377
+ }),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  data_collator=transformers.DataCollatorForSeq2Seq(
379
  tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
380
  ),