HaileyStorm committed on
Commit
62e35f0
1 Parent(s): ec6ad99

Update chess-mamba-vs-xformer/train_bygame.py

Browse files
chess-mamba-vs-xformer/train_bygame.py CHANGED
@@ -394,7 +394,10 @@ while True:
394
  if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
395
  torch.cuda.empty_cache()
396
  losses = estimate_loss()
397
- print(f"\ngame {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): 'val' loss {losses['val']:.4f}")
 
 
 
398
  if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
399
  grad_clip_prev = grad_clip
400
  grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
@@ -481,7 +484,10 @@ while True:
481
  # get loss as float. note: this is a CPU-GPU sync point
482
  # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
483
  lossf = loss.item() * gradient_accumulation_steps
484
- print(f"game {games_seen} ({iter_num}, {(iter_num / max_iters)*100.0:.3f}%): loss {lossf:.4f}, time {dt*1000:.2f}ms")
 
 
 
485
  if wandb_log:
486
  wandb.log({
487
  "etc/iter": iter_num,
 
394
  if iter_num % eval_interval == 0 and master_process and local_iter_num > 0:
395
  torch.cuda.empty_cache()
396
  losses = estimate_loss()
397
+ if init_from == 'anneal':
398
+ print(f"\ngame {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): 'val' loss {losses['val']:.4f}")
399
+ else:
400
+ print(f"\ngame {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): 'val' loss {losses['val']:.4f}")
401
  if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
402
  grad_clip_prev = grad_clip
403
  grad_clip = np.percentile(grad_norm_history, grad_clip_percentile)
 
484
  # get loss as float. note: this is a CPU-GPU sync point
485
  # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
486
  lossf = loss.item() * gradient_accumulation_steps
487
+ if init_from == 'anneal':
488
+ print(f"game {games_seen} ({iter_num}, {(iter_num-anneal_start_iters) / anneal_decay_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms")
489
+ else:
490
+ print(f"game {games_seen} ({iter_num}, {iter_num / max_iters:.3%}): loss {lossf:.4f}, time {dt*1000:.2f}ms")
491
  if wandb_log:
492
  wandb.log({
493
  "etc/iter": iter_num,