"""Train a Mamba language model to play chess by playing games against lc0.

The model plays one game per opening line from ``openings.csv``. Every model
move is scored with an lc0 evaluation taken before and after the move, and the
resulting reward weights the loss that is back-propagated during the game.
"""
import math
import os
import pickle
import random
import time
from contextlib import nullcontext

import chess
import numpy as np
import torch
import torch.nn.functional as F
from lczero.backends import Weights, Backend, GameState
from mamba_lm import MambaLM, MambaLMConfig
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

# Default config values
out_dir = 'out/play'
save_interval = 50
init_from = 'resume'  # 'scratch', 'resume', 'anneal', or Mamba model name

# Model parameters
n_layer = 15
d_model = 256
dt_rank = 'auto'
d_state = 16
vocab_size = 28
move_num_in_gamestate = False

# wandb logging
# (the earlier 'chess-training' / 'lc0-training' values were dead assignments
# that these ones overrode, so only the effective values are kept)
wandb_log = True
wandb_project = 'mamba-rl'
wandb_run_name = 'mamba_run'

# Load openings file
with open("openings.csv", "r") as file:
    lines = file.readlines()[1:]  # Skip header
    opening_lines = lines

# Optimizer settings
learning_rate = 1e-7  # 7.25e-7
min_lr = 1e-8  # 1.75e-8
warmup_iters = 600
lr_decay_iters = len(opening_lines)
weight_decay = 1e-2  # 5e-3
beta1 = 0.905  # 0.915
beta2 = 0.965  # 0.95
grad_clip = 0.5  # 0.25
min_grad_clip = 1e-3  # 1e-3
max_grad_clip = 0.45  # 0.45
auto_clip = True
grad_clip_start_size = 150
grad_clip_max_size = 600
grad_clip_percentile = 9

# Game play / loss calculation settings
top_k = 2  # 2
top_k_adj_moves = 40  # 999 #35
max_illegal_moves = 8  # 2
max_moves = 87
update_freq = 3  # 1 # How often to do a backward pass
flush_every = 1
# Scales down the move reward so it's not so dramatic / so that illegal moves
# (reward -1) are more dramatic by comparison to bad moves.
move_reward_scale_factor = 4.0  # 2.125
decrease_factor = 0.75  # Bonus for winning (1/x is penalty for losing)
window_size = 300

# DDP settings
backend = 'nccl'

# System
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Only query bf16 support when CUDA is actually present.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float32'
compile = False  # Set to True if using PyTorch 2.0

config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
#exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys}  # will be useful for logging

# Distributed setup. This has to happen before the model is moved to `device`
# so that each rank ends up on its own GPU.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)

# Checkpoints are written into out_dir; make sure it exists.
os.makedirs(out_dir, exist_ok=True)

# Initialize lc0 engines
lc0_weights_opponent = Weights("./lc0/build/release/11258-32x4-se.pb.gz")
lc0_backend_opponent = Backend(weights=lc0_weights_opponent)
lc0_weights_evaluator = Weights("./lc0/build/release/11258-48x5-se.pb.gz")
# The evaluator currently shares the opponent's backend.
lc0_backend_evaluator = lc0_backend_opponent  # Backend(weights=lc0_weights_evaluator)

# Load tokenizer and decode function
if move_num_in_gamestate:
    meta_path = os.path.join(os.path.join('data', 'chess'), 'meta.pkl')
    # NOTE: pickle.load executes arbitrary code; only use a trusted meta.pkl.
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    vocab_size = meta['vocab_size']
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: "".join([itos[i] for i in l])
else:
    stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9,
            '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17,
            'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23,
            'x': 24, '+': 25, '#': 26, '=': 27}
    # Inverse mapping, derived from stoi so the two tables cannot drift apart.
    itos = {i: ch for ch, i in stoi.items()}
    vocab_size = len(stoi)
    print(f"Vocab size {vocab_size}")
    # '-' is not in the vocabulary: castling is encoded as "OO" / "OOO" and
    # expanded back to "O-O" / "O-O-O" when decoding.
    encode = lambda s: [stoi[c] for c in s.replace('-', '')]
    decode = lambda l: "".join([itos[i] for i in l]).replace("OOO", "O-O-O").replace("OO", "O-O")

# Token ids that end a generated move (space and move separator). These were
# previously hard-coded as 0 and 4; id 4 is 'c' in the built-in vocabulary
# above, which truncated moves such as "Nc3".
# NOTE(review): assumes ids 0 and 4 are ' ' and '.' in meta.pkl - confirm.
stop_token_ids = frozenset(stoi[ch] for ch in (' ', '.') if ch in stoi)

# Initialize Mamba model
mamba_config = MambaLMConfig(
    d_model=d_model,
    n_layers=n_layer,
    dt_rank=dt_rank,
    d_state=d_state,
    vocab_size=vocab_size  # Adjust based on your dataset
)
model = MambaLM(mamba_config)
model.to(device)
# Reference to the unwrapped module. torch.compile and DDP wrap `model`
# (changing its state-dict key prefixes), so checkpoints are always loaded
# into / saved from this object.
base_model = model

# Optimizer and GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')

win_rate_window = []
win_only_rate_window = []

# Load checkpoint if resuming training
if init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    mamba_config = checkpoint['model_args']
    state_dict = checkpoint['model']
    # Checkpoints saved from a torch.compile'd model carry this key prefix;
    # strip it so the keys match the unwrapped module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    base_model.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer'])
    iter_num = checkpoint['iter_num']
    games_played = checkpoint['games_seen']
    opening_line_index = checkpoint.get('opening_line_index', 0)
    win_rate_window = checkpoint.get('win_rate_window', [])
    win_only_rate_window = checkpoint.get('win_only_rate_window', [])
    best_wr = checkpoint.get('best_wr', 0.0)
    best_wor = checkpoint.get('best_wor', 0.0)
    if auto_clip:
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
        grad_norm_history = checkpoint.get('grad_norm_history', [])
    else:
        grad_norm_history = []
else:
    best_wr = 0.0
    best_wor = 0.0
    grad_norm_history = []
    games_played = 0
    iter_num = 0
    opening_line_index = 0
    if auto_clip:
        grad_clip = 0
        config['grad_clip'] = 0

# Compile the model if using PyTorch 2.0
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

# Wrap model in DDP container if necessary
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])


def get_model_move(game_state, top_k):
    """Sample one move from the model as a continuation of ``game_state``.

    Up to 8 tokens are sampled one at a time. Generation stops early when a
    stop token (space or '.') is sampled after at least one token has been
    produced.

    Args:
        game_state: Text of the game so far; used as the prompt.
        top_k: If not None and smaller than the vocabulary, sampling is
            restricted to the ``top_k`` highest logits and only those logits
            are kept for the loss.

    Returns:
        ``(move, logits)`` where ``move`` is the first whitespace-separated
        token of the generated text (None if nothing usable was generated) and
        ``logits`` stacks the logits kept at each generated step (None if no
        step was recorded).
    """
    model.train()  # Ensure the model is in training mode
    encoded_prompt = encode(game_state)
    input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
    have_non_space = False
    logits_list = []  # Collect logits for analysis and potential loss calculation
    for _ in range(8):
        logits = model(input_ids)[0, -1, :]  # Logits for the last predicted token
        # We're using top-k more as a VRAM control, not a decision enhancing tool
        if top_k is not None and top_k < logits.size(-1):
            logits, indices = torch.topk(logits, top_k)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = indices[torch.multinomial(probs, 1)]
        else:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        # The first sampled token is always kept, whatever it is.
        if have_non_space and next_token_id.item() in stop_token_ids:
            break
        have_non_space = True
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
        logits_list.append(logits)
        del logits, probs
    # Decode the sequence to extract the move
    model_response = decode(input_ids.squeeze(0).tolist())
    try:
        move = model_response[len(game_state):].split(";")[0].split()[0]  # Extract the first move
    except IndexError:
        move = None
    return move, torch.stack(logits_list) if len(logits_list) > 0 else None


def get_lc0_move(board, backend):
    """Return lc0's highest-policy move for ``board``, or None if it has none.

    Args:
        board: ``chess.Board`` describing the position (only its FEN is used).
        backend: lc0 backend used for the evaluation.

    Returns:
        A ``chess.Move`` built from the UCI string of the move with the
        largest policy probability, or None when lc0 reports no moves.
    """
    gamestate = GameState(fen=board.fen())
    moves = gamestate.moves()
    if len(moves) == 0:
        return None
    input_planes = gamestate.as_input(backend)
    result = backend.evaluate(input_planes)[0]
    policy_indices = gamestate.policy_indices()
    move_probs = np.array(result.p_softmax(*policy_indices))
    try:
        best_move_idx = move_probs.argmax()
    except ValueError:  # argmax of an empty array
        return None
    best_move = moves[best_move_idx]
    return chess.Move.from_uci(best_move)


def evaluate_position(fen, backend):
    """Return lc0's ``q()`` value for the position given by ``fen``."""
    gamestate = GameState(fen=fen)
    result = backend.evaluate(gamestate.as_input(backend))[0]
    return result.q()


def reward_from_eval(before_eval, after_eval):
    """Map an evaluation change to a bounded reward.

    The difference ``after_eval - before_eval`` is squashed with
    ``diff / (move_reward_scale_factor + |diff|)``, so the magnitude of the
    result is always below 1.
    """
    diff = after_eval - before_eval
    return diff / (move_reward_scale_factor + abs(diff))


def backward_pass(loss):
    """Back-propagate ``loss``, clip gradients and take one optimizer step.

    The pre-clip gradient norm is appended to ``grad_norm_history`` (trimmed
    to the last ``grad_clip_max_size`` entries), which drives the automatic
    ``grad_clip`` adjustment in the training loop.
    """
    global grad_norm_history
    # Backward pass
    scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0 or auto_clip:
        scaler.unscale_(optimizer)
        # The 0 check is for auto_clip enabled but not enough history
        max_norm = grad_clip if grad_clip != 0.0 else 0.1
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)


def _play_pinch_hit_move(board):
    """Have lc0 play a move in place of the model; return its SAN or None."""
    pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
    if pinch_hit_move is None:
        print("Failed game (lc0 couldn't find pinch-hit move)")
        return None
    san = board.san(pinch_hit_move)
    board.push(pinch_hit_move)
    return san


def play_game():
    """Play one game (model vs. lc0) from the global ``opening_line``.

    The model moves first after the opening, lc0 replies. Illegal or missing
    model moves get reward -1 and are replaced by an lc0 move. Loss is
    accumulated per move and back-propagated every ``update_freq`` moves, plus
    once at the end scaled by a factor that depends on the result.

    Returns:
        ``(reward, result, illegal_moves, move_count)`` where ``reward`` is
        the average move reward divided by the final scale factor and
        ``result`` is "1-0", "0-1", "1/2-1/2" or "*" (failed game).
    """
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    board = chess.Board()
    total_loss = 0
    illegal_moves = 0
    move_count = 0
    moves_since_backward = 0
    tot_reward = 0
    # Per-game copy, so the global top_k is never left modified when a game
    # ends early.
    current_top_k = top_k

    # Load opening from openings.csv
    opening_tokens = opening_line.split()
    for token in opening_tokens:
        board.push_san(token.split(".")[-1] if "." in token else token)
    if move_num_in_gamestate:
        game_state = opening_line.rstrip() + " "
    else:
        # NOTE(review): unlike the branch above this leaves no trailing space
        # before the first '.' appended below - confirm this is intended.
        game_state = ' '.join(['.' + m.split(".")[-1] if "." in m else m for m in opening_tokens])

    fail = False
    while not board.is_game_over():
        before_eval = evaluate_position(board.fen(), lc0_backend_evaluator)
        game_state += f"{board.fullmove_number if move_num_in_gamestate else ''}."
        model_move, logits = get_model_move(game_state, current_top_k)
        move_reward = -1  # Reward for an illegal / missing move

        model_played = False
        if model_move is not None and logits is not None:
            try:
                board.push(board.parse_san(model_move))
            except ValueError:
                # python-chess signals invalid, illegal and ambiguous SAN with
                # ValueError subclasses.
                pass
            else:
                if board.is_valid():
                    model_played = True
                else:
                    board.pop()

        if model_played:
            game_state += f"{model_move} "
            # Negated because the evaluation is taken after the model's move.
            # presumably q() is from the side to move - TODO confirm
            after_eval = -evaluate_position(board.fen(), lc0_backend_evaluator)
            move_reward = reward_from_eval(before_eval, after_eval)
        else:
            illegal_moves += 1
            pinch_hit_san = _play_pinch_hit_move(board)
            if pinch_hit_san is None:
                fail = True
                tot_reward += move_reward
                move_count += 1
                break
            game_state += f"{pinch_hit_san} "

        tot_reward += move_reward

        if not board.is_game_over():
            black_move = get_lc0_move(board, lc0_backend_opponent)
            if black_move is None:
                print("Failed game (lc0 couldn't find black move)")
                fail = True
                move_count += 1
                break
            game_state += f"{board.san(black_move)} "
            board.push(black_move)

        if logits is not None:
            total_loss += torch.sum(torch.nn.functional.log_softmax(logits, dim=-1) * move_reward)
        logits_none = logits is None
        del logits
        moves_since_backward += 1

        if move_count % update_freq == 0 and not board.is_game_over() and not logits_none:
            backward_pass(total_loss / moves_since_backward)
            total_loss = 0.0  # Reset cumulative loss after update
            moves_since_backward = 0

        move_count += 1
        if move_count == top_k_adj_moves and current_top_k is not None:
            # Narrow sampling from this point on; never go below 1 candidate.
            current_top_k = max(1, current_top_k - 1)
        if move_count >= max_moves:
            break
        if move_count % flush_every == 0:
            torch.cuda.empty_cache()

    # Scale loss based on game result and illegal moves
    avg_reward = tot_reward / move_count if move_count else 0.0
    scale_factor = torch.tensor([1.0], device=device)
    if move_count >= max_moves:
        result = "1/2-1/2"
    elif fail:
        result = "*"
    else:
        result = board.result()

    if result == "0-1":  # Black wins
        # Increase the loss for a loss, if the reward is negative (if the loss is positive)
        if avg_reward < 0 and illegal_moves <= max_illegal_moves:
            scale_factor = torch.tensor([1.0 / decrease_factor], device=device)
    elif result == "1-0":  # White wins
        wdf = decrease_factor / 2.0 if avg_reward <= 0 else 1.0 / decrease_factor
        # Don't update as much for (real) wins. Also change the result so our win_rate isn't inflated.
        if illegal_moves == 0:
            scale_factor = torch.tensor([wdf], device=device)
        elif illegal_moves <= max_illegal_moves:
            scale_factor = torch.tensor([(1 + wdf) / 2], device=device)
            result = "1/2-1/2"
        else:
            result = "0-1"  # No adjustment to scale_factor

    # total_loss is only a tensor if at least one move contributed since the
    # last backward pass; otherwise it is still the plain number 0 / 0.0 and
    # there is nothing left to back-propagate.
    if torch.is_tensor(total_loss) and moves_since_backward > 0 and total_loss.numel():
        total_loss = total_loss / moves_since_backward
        try:
            backward_pass(total_loss * scale_factor)
        except Exception as exc:  # best effort: the game result is still reported
            print(f"Failed game (final backward pass, result not effected): {exc}")
    total_loss = 0.0
    return avg_reward / scale_factor.item(), result, illegal_moves, move_count


def get_lr(it):
    """Return the learning rate for iteration ``it``.

    Linear warmup for ``warmup_iters`` iterations, then cosine decay from
    ``learning_rate`` to ``min_lr`` until ``lr_decay_iters``, then ``min_lr``.
    """
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) at or past lr_decay_iters, return min learning rate. Using >= also
    #    avoids a 0/0 below when warmup_iters == lr_decay_iters.
    if it >= lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)


def _build_checkpoint():
    """Collect everything needed to resume training into one dict."""
    return {
        'model': base_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': mamba_config,
        'iter_num': iter_num,
        "games_seen": games_played,
        'config': config,
        'opening_line_index': opening_line_index,
        'grad_norm_history': grad_norm_history,
        'win_rate_window': win_rate_window,
        'win_only_rate_window': win_only_rate_window,
        'best_wr': best_wr,
        'best_wor': best_wor
    }


def _save_checkpoint(filename):
    """Write the current training state to ``out_dir/filename``."""
    print(f"saving checkpoint to {out_dir}\n")
    torch.save(_build_checkpoint(), os.path.join(out_dir, filename))


# Training loop
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# One game per opening line; stops when the openings are exhausted (also when
# resuming from a checkpoint that had already reached the end).
while opening_line_index < len(opening_lines):
    t0 = time.time()
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    opening_line = opening_lines[opening_line_index]

    if iter_num > 0 and iter_num % save_interval == 0:
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            # Clip at a low percentile of recent gradient norms, kept inside
            # [min_grad_clip, max_grad_clip].
            grad_clip = float(max(min(np.percentile(grad_norm_history, grad_clip_percentile), max_grad_clip), min_grad_clip))
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/lr": lr,
                "etc/grad_clip": grad_clip,
                "etc/games_played": games_played,
            })
        # Save checkpoint
        _save_checkpoint('ckpt.pt')

    # Play a game against lc0 engine
    game_reward, result, illegal_moves, move_count = play_game()
    games_played += 1
    # Backward passes happen in play_game

    # Log game result and update win rate window
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    score = 0.5
    if result == "1-0":
        score = 1
    elif result == "0-1":
        score = 0
    if result != "*":
        win_rate_window.append(score)
        win_rate_window = win_rate_window[-window_size:]
        win_only_rate_window.append(int(score))  # int to discard draws
        win_only_rate_window = win_only_rate_window[-window_size:]
    # Computed outside the branch above so both rates are defined even when
    # the very first game fails.
    win_rate = sum(win_rate_window) / len(win_rate_window) if win_rate_window else 0.0
    win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window) if win_only_rate_window else 0.0

    if win_rate > best_wr:
        best_wr = win_rate
        _save_checkpoint(f'ckpt_{games_played}g_wr{best_wr}.pt')
    elif win_only_rate > best_wor:
        best_wor = win_only_rate
        _save_checkpoint(f'ckpt_{games_played}g_wor{best_wor}.pt')
    best_wor = max(best_wor, win_only_rate)

    illegal_move_pct = illegal_moves / move_count if move_count else 0.0
    print(f"Game {games_played} ({iter_num}, {(iter_num / len(opening_lines)) * 100.0:.3f}%): Score {score}, Reward {game_reward:.4f}, Illegal moves {illegal_moves} ({illegal_move_pct:.3%}), Total moves {move_count}, Win rate {win_rate:.3f}, Win only rate {win_only_rate:.3f}, time {dt * 1000:.2f}ms")
    if wandb_log:
        # Convert to an array first: comparing the list itself with 0 yields a
        # single False, which made this metric always 0.
        grad_norms = np.asarray(grad_norm_history)
        wandb.log({
            "etc/iter": iter_num,
            "etc/lr": lr,
            "etc/grad_norm_mean": np.mean(grad_norms) if grad_norm_history else -1,
            "etc/grad_zero_pct": float(np.count_nonzero(grad_norms == 0)) / len(grad_norm_history) if grad_norm_history else -1,
            "etc/games_played": games_played,
            "eval/game_reward": game_reward,
            "eval/illegal_move_pct": illegal_move_pct,
            "eval/move_ct": move_count,
            "eval/win_rate": win_rate,
            "eval/win_only_rate": win_only_rate,
        })

    iter_num += 1
    opening_line_index += 1

if ddp:
    destroy_process_group()