HaileyStorm committed on
Commit
5928118
1 Parent(s): 4fcec20

Update chess-mamba-vs-xformer/train_bygame.py

Browse files
chess-mamba-vs-xformer/train_bygame.py CHANGED
@@ -394,6 +394,9 @@ if init_from == 'scratch':
394
  print(f"saving checkpoint to {out_dir}\n")
395
  torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
396
 
 
 
 
397
  t0 = time.time()
398
  while True:
399
  # Determine and set the learning rate for this iteration
@@ -402,7 +405,7 @@ while True:
402
  param_group['lr'] = lr
403
 
404
  # Evaluate the loss on train/val sets and write checkpoints
405
- if master_process and ((iter_num % eval_interval == 0 and local_iter_num > 0) or abs(games_seen - 12652800) <= 151 or abs(games_seen - 22275000) <= 151 or abs(games_seen - 11536000) <= 151 or abs(games_seen - 16250000) <= 151 or abs(games_seen - 18000000) <= 151 or abs(games_seen - 19690000) <= 151 or abs(games_seen - 22005050) <= 151 or abs(tokens_seen_padded - 7798839804) <= 46238):
406
  torch.cuda.empty_cache()
407
  losses = estimate_loss()
408
  if init_from == 'anneal':
@@ -453,7 +456,7 @@ while True:
453
  if losses['val'] < best_val_loss: # Temporary / only good after it's settled
454
  best_val_loss = losses['val']
455
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
456
- elif current_nearest_multiple != last_crossed_multiple or abs(games_seen - 12652800) <= 151 or abs(games_seen - 22275000) <= 151 or abs(games_seen - 11536000) <= 151 or abs(games_seen - 16250000) <= 151 or abs(games_seen - 18000000) <= 151 or abs(games_seen - 19690000) <= 151 or abs(games_seen - 22005050) <= 151 or abs(tokens_seen_padded - 7798839804) <= 46238: # elif so we don't double up
457
  last_crossed_multiple = current_nearest_multiple
458
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}g_{tokens_seen_padded}t.pt'))
459
 
 
394
  print(f"saving checkpoint to {out_dir}\n")
395
  torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
396
 
397
+ GAMES_SEEN_CHECKPOINTS = [12652800, 22275000, 11536000, 16250000, 18000000, 19690000, 22005050]
398
+ TOKENS_SEEN_PADDED_CHECKPOINTS = [7798839804]
399
+
400
  t0 = time.time()
401
  while True:
402
  # Determine and set the learning rate for this iteration
 
405
  param_group['lr'] = lr
406
 
407
  # Evaluate the loss on train/val sets and write checkpoints
408
+ if master_process and ((iter_num % eval_interval == 0 and local_iter_num > 0) or any(abs(games_seen - checkpoint) <= 151 for checkpoint in GAMES_SEEN_CHECKPOINTS) or any(abs(tokens_seen_padded - checkpoint) <= 46238 for checkpoint in TOKENS_SEEN_PADDED_CHECKPOINTS)):
409
  torch.cuda.empty_cache()
410
  losses = estimate_loss()
411
  if init_from == 'anneal':
 
456
  if losses['val'] < best_val_loss: # Temporary / only good after it's settled
457
  best_val_loss = losses['val']
458
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}b.pt'))
459
+ elif current_nearest_multiple != last_crossed_multiple or any(abs(games_seen - checkpoint) <= 151 for checkpoint in GAMES_SEEN_CHECKPOINTS) or any(abs(tokens_seen_padded - checkpoint) <= 46238 for checkpoint in TOKENS_SEEN_PADDED_CHECKPOINTS): # elif so we don't double up
460
  last_crossed_multiple = current_nearest_multiple
461
  torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{int(games_seen)}g_{tokens_seen_padded}t.pt'))
462