zetavg commited on
Commit
05ad97e
·
1 Parent(s): 8e2e7b5

finetune: log trainable% to wandb

Browse files
Files changed (1) hide show
  1. llama_lora/lib/finetune.py +11 -0
llama_lora/lib/finetune.py CHANGED
@@ -275,7 +275,18 @@ def train(
275
  raise ValueError(f"Checkpoint {checkpoint_name} not found")
276
 
277
  # Be more transparent about the % of trainable params.
 
 
 
 
 
 
 
 
 
278
  model.print_trainable_parameters()
 
 
279
 
280
  if val_set_size > 0:
281
  train_val = train_data.train_test_split(
 
275
  raise ValueError(f"Checkpoint {checkpoint_name} not found")
276
 
277
  # Be more transparent about the % of trainable params.
278
+ trainable_params = 0
279
+ all_param = 0
280
+ for _, param in model.named_parameters():
281
+ all_param += param.numel()
282
+ if param.requires_grad:
283
+ trainable_params += param.numel()
284
+ print(
285
+ f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param} (calculated)"
286
+ )
287
  model.print_trainable_parameters()
288
+ if use_wandb and wandb:
289
+ wandb.config.update({"model": { "all_param": all_param, "trainable_params": trainable_params, "trainable%": 100 * trainable_params / all_param }})
290
 
291
  if val_set_size > 0:
292
  train_val = train_data.train_test_split(