zetavg committed
Commit 05ad97e · 1 Parent(s): 8e2e7b5
finetune: log trainable% to wandb
llama_lora/lib/finetune.py  (+11 -0)  CHANGED
@@ -275,7 +275,18 @@ def train(
             raise ValueError(f"Checkpoint {checkpoint_name} not found")
 
     # Be more transparent about the % of trainable params.
+    trainable_params = 0
+    all_param = 0
+    for _, param in model.named_parameters():
+        all_param += param.numel()
+        if param.requires_grad:
+            trainable_params += param.numel()
+    print(
+        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param} (calculated)"
+    )
     model.print_trainable_parameters()
+    if use_wandb and wandb:
+        wandb.config.update({"model": { "all_param": all_param, "trainable_params": trainable_params, "trainable%": 100 * trainable_params / all_param }})
 
     if val_set_size > 0:
         train_val = train_data.train_test_split(
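For reference, the logic this commit adds can be exercised on its own roughly as in the sketch below: count parameters with model.named_parameters() and param.numel(), then record the percentage in the W&B run config via wandb.config.update. The tiny nn.Linear model and the wandb.init(mode="disabled") call are stand-in assumptions for illustration only; in the commit itself the same code runs inside train() against the PEFT-wrapped model, with the W&B run configured by the surrounding script.

# Minimal sketch, not the project's code: the same counting/logging pattern in isolation.
import torch.nn as nn
import wandb


def count_trainable_params(model):
    """Return (trainable_params, all_param) by summing parameter element counts."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, all_param


if __name__ == "__main__":
    model = nn.Linear(128, 64)        # stand-in for the PEFT-wrapped LLaMA model
    model.bias.requires_grad = False  # freeze something so trainable% < 100

    trainable_params, all_param = count_trainable_params(model)
    trainable_pct = 100 * trainable_params / all_param
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}")

    run = wandb.init(mode="disabled")  # disabled mode: no account or network needed
    wandb.config.update({"model": {
        "all_param": all_param,
        "trainable_params": trainable_params,
        "trainable%": trainable_pct,
    }})
    run.finish()

Writing the counts into wandb.config (rather than wandb.log) records them as run-level metadata instead of a time series, which fits the commit's treatment of trainable% as a fixed property of the model.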