zetavg committed · Commit 2aa964c · 1 parent: 34968a1

fix

Files changed: llama_lora/lib/finetune.py (+45 -36)
llama_lora/lib/finetune.py
CHANGED
@@ -92,7 +92,7 @@ def train(
     try:
         additional_lora_config = json.loads(additional_lora_config)
     except Exception as e:
-        raise ValueError(f"Could not parse
+        raise ValueError(f"Could not parse additional_lora_config: {e}")
 
     # for logging
     finetune_args = {
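The first hunk only changes the error raised when additional_lora_config, which reaches train() as a JSON string, cannot be parsed: the new message names the field and includes the underlying JSON error. A minimal standalone sketch of that parse-or-raise pattern (parse_additional_config is a hypothetical helper for illustration, not a function in this repo):

import json

def parse_additional_config(raw):
    # Treat None or an empty string as "no overrides".
    if not raw:
        return {}
    try:
        return json.loads(raw)
    except Exception as e:
        # Surface the original JSON error, mirroring the message used in finetune.py.
        raise ValueError(f"Could not parse additional_lora_config: {e}")

print(parse_additional_config('{"lora_dropout": 0.1}'))   # {'lora_dropout': 0.1}
print(parse_additional_config(None))                       # {}
# parse_additional_config("{not json")                     # raises ValueError("Could not parse ...")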
@@ -262,16 +262,19 @@ def train(
 
     # model = prepare_model_for_int8_training(model)
 
-    config = LoraConfig(
-        r=lora_r,
-        lora_alpha=lora_alpha,
-        target_modules=lora_target_modules,
-        modules_to_save=lora_modules_to_save,
-        lora_dropout=lora_dropout,
-        bias="none",
-        task_type="CAUSAL_LM",
-        **additional_lora_config,
-    )
+    lora_config_args = {
+        'r': lora_r,
+        'lora_alpha': lora_alpha,
+        'target_modules': lora_target_modules,
+        'modules_to_save': lora_modules_to_save,
+        'lora_dropout': lora_dropout,
+        'bias': "none",
+        'task_type': "CAUSAL_LM",
+    }
+    config = LoraConfig(**{
+        **lora_config_args,
+        **(additional_lora_config or {}),
+    })
     model = get_peft_model(model, config)
     if bf16:
         model = model.to(torch.bfloat16)
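The second hunk collects the default LoRA settings into a lora_config_args dict and merges the user-supplied additional_lora_config on top before constructing peft's LoraConfig; the "or {}" guard keeps the merge from failing when no additional config was given. A plain-dict sketch of the override semantics, with illustrative values only (the real code splats the merged dict into LoraConfig):

# Plain dicts only; in finetune.py the merged dict is splatted into peft's LoraConfig.
base_lora_args = {
    'r': 8,
    'lora_alpha': 16,
    'lora_dropout': 0.05,
    'bias': "none",
    'task_type': "CAUSAL_LM",
}

additional_lora_config = {'r': 16, 'lora_dropout': 0.1}    # e.g. parsed from user-supplied JSON
merged = {
    **base_lora_args,
    **(additional_lora_config or {}),   # `or {}` keeps this safe when no overrides were given
}

print(merged['r'], merged['lora_dropout'])   # 16 0.1 -- later keys win over the defaults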
@@ -336,36 +339,42 @@ def train(
     model.is_parallelizable = True
     model.model_parallel = True
 
+    # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
+    training_args = {
+        'output_dir': output_dir,
+        'per_device_train_batch_size': micro_batch_size,
+        'gradient_checkpointing': gradient_checkpointing,
+        'gradient_accumulation_steps': gradient_accumulation_steps,
+        'warmup_steps': 100,
+        'num_train_epochs': num_train_epochs,
+        'learning_rate': learning_rate,
+        'fp16': fp16,
+        'bf16': bf16,
+        'logging_steps': logging_steps,
+        'optim': "adamw_torch",
+        'evaluation_strategy': "steps" if val_set_size > 0 else "no",
+        'save_strategy': "steps",
+        'eval_steps': save_steps if val_set_size > 0 else None,
+        'save_steps': save_steps,
+        'output_dir': output_dir,
+        'save_total_limit': save_total_limit,
+        'load_best_model_at_end': True if val_set_size > 0 else False,
+        'ddp_find_unused_parameters': False if ddp else None,
+        'group_by_length': group_by_length,
+        'report_to': "wandb" if use_wandb else None,
+        'run_name': wandb_run_name if use_wandb else None,
+    }
+
     # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
         eval_dataset=val_data,
-        # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
-        args=transformers.TrainingArguments(
-            per_device_train_batch_size=micro_batch_size,
-            gradient_checkpointing=gradient_checkpointing,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            warmup_steps=100,
-            num_train_epochs=num_train_epochs,
-            learning_rate=learning_rate,
-            fp16=fp16,
-            bf16=bf16,
-            logging_steps=logging_steps,
-            optim="adamw_torch",
-            evaluation_strategy="steps" if val_set_size > 0 else "no",
-            save_strategy="steps",
-            eval_steps=save_steps if val_set_size > 0 else None,
-            save_steps=save_steps,
-            output_dir=output_dir,
-            save_total_limit=save_total_limit,
-            load_best_model_at_end=True if val_set_size > 0 else False,
-            ddp_find_unused_parameters=False if ddp else None,
-            group_by_length=group_by_length,
-            report_to="wandb" if use_wandb else None,
-            run_name=wandb_run_name if use_wandb else None,
-            **additional_training_arguments
-        ),
+        tokenizer=tokenizer,
+        args=transformers.TrainingArguments(**{
+            **training_args,
+            **(additional_training_arguments or {})
+        }),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
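The third hunk applies the same pattern to the trainer: the defaults move into a training_args dict, additional_training_arguments is merged on top, and the merged dict is splatted into transformers.TrainingArguments; the Trainer call also gains tokenizer=tokenizer. A standalone plain-dict sketch with made-up values; note that repeating a key in a dict literal, as the new code does with 'output_dir', is legal Python and the last occurrence wins:

# Plain dicts only; in finetune.py the merged dict is splatted into transformers.TrainingArguments.
training_args = {
    'output_dir': "./output",
    'warmup_steps': 100,
    'optim': "adamw_torch",
    'save_strategy': "steps",
    'output_dir': "./output",   # repeating a key in a dict literal is legal; the last value wins
}

additional_training_arguments = {'warmup_steps': 0, 'report_to': "none"}   # user-supplied overrides
merged = {
    **training_args,
    **(additional_training_arguments or {}),
}

print(merged['warmup_steps'], merged['report_to'])   # 0 none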
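The data collator line is unchanged context: DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True) pads each batch to its longest sequence and then rounds that length up to a multiple of 8, which tends to suit fp16/bf16 kernels better. A tiny illustration of the rounding (padded_length is a hypothetical helper, not library code):

def padded_length(longest_in_batch, multiple=8):
    # Round the batch's longest sequence length up to the next multiple of `multiple`.
    return ((longest_in_batch + multiple - 1) // multiple) * multiple

print(padded_length(13))  # 16
print(padded_length(16))  # 16
print(padded_length(17))  # 24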