import math |
import os |
import time |
import pickle |
from contextlib import nullcontext |
import numpy as np |
import torch |
import torch.nn.functional as F |
from torch.nn.parallel import DistributedDataParallel as DDP |
from torch.distributed import init_process_group, destroy_process_group |
from mamba_lm import MambaLM, MambaLMConfig |
import random |
import chess |
from lczero.backends import Weights, Backend, GameState |
# -----------------------------------------------------------------------------
# Run configuration.
# Every int/float/bool/str global that exists when `config_keys` is built
# (bottom of this section) is snapshotted into `config`, which is passed to
# wandb and stored in checkpoints.
# -----------------------------------------------------------------------------
# I/O
out_dir = 'out/play'
save_interval = 50  # out_dir/ckpt.pt is rewritten every this many iterations
init_from = 'resume'  # 'resume' reloads out_dir/ckpt.pt; any other value starts fresh

# model
n_layer = 15
d_model = 256
dt_rank = 'auto'
d_state = 16
vocab_size = 28
move_num_in_gamestate = False  # True: prompts carry move numbers and the vocab comes from meta.pkl

# wandb logging (these names were previously assigned twice; only the final
# values ever took effect, so only those are kept)
wandb_log = True
wandb_project = 'mamba-rl'
wandb_run_name = 'mamba_run'

# One opening per line; the first line of the file is skipped.
with open("openings.csv", "r") as openings_file:
    opening_lines = openings_file.readlines()[1:]

# optimizer / learning-rate schedule
learning_rate = 1e-7
min_lr = 1e-8
warmup_iters = 600
lr_decay_iters = len(opening_lines)  # one iteration per opening
weight_decay = 1e-2
beta1 = 0.905
beta2 = 0.965

# gradient clipping; with auto_clip the threshold is re-derived from a
# percentile of the recorded gradient norms
grad_clip = 0.5
min_grad_clip = 1e-3
max_grad_clip = 0.45
auto_clip = True
grad_clip_start_size = 150  # norms required before auto-adjustment starts
grad_clip_max_size = 600  # number of most recent norms kept
grad_clip_percentile = 9

# game play
top_k = 2  # sampling is restricted to the top_k most likely tokens
top_k_adj_moves = 40  # top_k is lowered by one once a game reaches this many moves
max_illegal_moves = 8
max_moves = 87  # games reaching this many moves are cut off and scored as draws
update_freq = 3  # an optimizer step is taken every update_freq moves
flush_every = 1  # torch.cuda.empty_cache() cadence, in moves
move_reward_scale_factor = 4.0
decrease_factor = 0.75
window_size = 300  # games in the rolling win-rate windows

# system
backend = 'nccl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Ask about bf16 only when CUDA is present: on some torch versions the query
# touches the current CUDA device and fails on CPU-only hosts.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float32'
compile = False

config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}
# Network that plays the opponent's moves and substitutes for the model when
# the model's own move cannot be played.
lc0_weights_opponent = Weights("./lc0/build/release/11258-32x4-se.pb.gz")
lc0_backend_opponent = Backend(weights=lc0_weights_opponent)
# Network that scores positions for the per-move reward. It gets its own
# backend: previously the evaluator weights were loaded but the evaluator
# backend was an alias of the opponent backend, so this network was never used.
# NOTE(review): if sharing one backend was deliberate (e.g. to save memory),
# drop the evaluator weights instead of loading them.
lc0_weights_evaluator = Weights("./lc0/build/release/11258-48x5-se.pb.gz")
lc0_backend_evaluator = Backend(weights=lc0_weights_evaluator)
# Character-level vocabulary plus the `encode` / `decode` helpers built on it.
if not move_num_in_gamestate:
    # Built-in vocabulary: a symbol's position in this string is its token id.
    stoi = {symbol: index for index, symbol in enumerate(" .abcdefgh12345678BNRQKOx+#=")}
    itos = {index: symbol for symbol, index in stoi.items()}
    # Sanity check: the two tables must be inverses of each other.
    assert all(itos[stoi[symbol]] == symbol for symbol in stoi)
    vocab_size = len(stoi)
    print(f"Vocab size {vocab_size}")

    def encode(s):
        """Map text to token ids; '-' is not in the vocabulary and is dropped."""
        return [stoi[c] for c in s.replace('-', '')]

    def decode(l):
        """Map token ids back to text, restoring the dashes in castling moves."""
        text = "".join(itos[i] for i in l)
        return text.replace("OOO", "O-O-O").replace("OO", "O-O")
else:
    # Vocabulary stored next to the dataset. The file is unpickled as-is, so it
    # must come from a trusted source.
    meta_path = os.path.join('data', 'chess', 'meta.pkl')
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    vocab_size = meta['vocab_size']

    def encode(s):
        """Map text to token ids using the loaded vocabulary."""
        return [stoi[c] for c in s]

    def decode(l):
        """Map token ids back to text using the loaded vocabulary."""
        return "".join(itos[i] for i in l)
# A RANK environment variable means the script was started by a distributed
# launcher. The process group has to exist and the local rank has to be known
# before the model can be wrapped in DDP; previously `ddp_local_rank` was never
# defined, so any distributed launch failed with a NameError.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    # Each process drives its own GPU. (`config` was snapshotted earlier and
    # keeps the generic device string.)
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)

mamba_config = MambaLMConfig(
    d_model=d_model,
    n_layers=n_layer,
    dt_rank=dt_rank,
    d_state=d_state,
    vocab_size=vocab_size,
)
model = MambaLM(mamba_config)
model.to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
    betas=(beta1, beta2),
)
# Loss scaling is only switched on for float16; otherwise the scaler is a
# pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
# Training-progress state: restored from out_dir/ckpt.pt when resuming,
# otherwise initialised for a fresh run.
win_rate_window = []
win_only_rate_window = []
if init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    # The checkpoint stores pickled Python objects (model config, run config)
    # next to the tensors, so it cannot be read in weights-only mode. Only load
    # checkpoints written by this script.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    mamba_config = checkpoint['model_args']
    state_dict = checkpoint['model']
    # Checkpoints taken from a compiled model carry this prefix on every key.
    unwanted_prefix = '_orig_mod.'
    for key in [name for name in state_dict if name.startswith(unwanted_prefix)]:
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    # Load into the bare module: by this point `model` may be wrapped by DDP
    # and/or torch.compile, and both wrappers expect prefixed keys that the
    # saved (and stripped) state dict does not have.
    load_target = model.module if ddp else model
    load_target = getattr(load_target, '_orig_mod', load_target)
    load_target.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer'])
    iter_num = checkpoint['iter_num']
    games_played = checkpoint['games_seen']
    opening_line_index = checkpoint.get('opening_line_index', 0)
    win_rate_window = checkpoint.get('win_rate_window', [])
    win_only_rate_window = checkpoint.get('win_only_rate_window', [])
    best_wr = checkpoint.get('best_wr', 0.0)
    best_wor = checkpoint.get('best_wor', 0.0)
    if auto_clip:
        # Continue from the clip threshold the previous run had arrived at.
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
        grad_norm_history = checkpoint.get('grad_norm_history', [])
    else:
        grad_norm_history = []
else:
    best_wr = 0.0
    best_wor = 0.0
    grad_norm_history = []
    games_played = 0
    iter_num = 0
    opening_line_index = 0
    if auto_clip:
        # 0 means "no threshold derived yet"; backward_pass substitutes a
        # fixed value until auto-adjustment sets one.
        grad_clip = 0
        config['grad_clip'] = 0
def get_model_move(game_state, top_k):
    """Sample the model's next move, one character token at a time.

    Args:
        game_state: prompt text (the game so far); it is encoded with `encode`.
        top_k: if not None and smaller than the vocabulary, sampling is
            restricted to the `top_k` highest-scoring tokens.

    Returns:
        A tuple `(move, logits)`. `move` is the first whitespace-delimited
        word generated after the prompt, or None if nothing usable was
        generated. `logits` stacks the (top-k filtered, when applicable)
        logits of every accepted token, or is None if no token was accepted.
    """
    model.train()
    # Generation stops at a move separator. The ids are looked up in the active
    # vocabulary: the previously hard-coded id 4 is 'c' in the built-in
    # vocabulary, which cut off moves such as "Nc3" and never matched '.'.
    stop_ids = {stoi[ch] for ch in (' ', '.') if ch in stoi}
    input_ids = torch.tensor([encode(game_state)], dtype=torch.long, device=device)
    have_non_space = False
    logits_list = []
    for _ in range(8):  # at most 8 tokens per move
        logits = model(input_ids)[0, -1, :]
        if top_k is not None and top_k < logits.size(-1):
            logits, indices = torch.topk(logits, top_k)
            probs = F.softmax(logits, dim=-1)
            # Map the sampled position back to a vocabulary id.
            next_token_id = indices[torch.multinomial(probs, 1)]
        else:
            probs = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        # The first sampled token is always kept; after that a separator ends
        # the move.
        if have_non_space and next_token_id.item() in stop_ids:
            break
        have_non_space = True
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
        logits_list.append(logits)
        del logits, probs
    model_response = decode(input_ids.squeeze(0).tolist())
    try:
        move = model_response[len(game_state):].split(";")[0].split()[0]
    except IndexError:
        # Only whitespace (or nothing) was generated.
        move = None
    stacked_logits = torch.stack(logits_list) if logits_list else None
    return move, stacked_logits
def get_lc0_move(board, backend):
    """Return the move lc0's policy rates highest in the current position.

    Args:
        board: python-chess board; only its FEN is handed to lc0.
        backend: lc0 backend used for the evaluation.

    Returns:
        The chosen move as a `chess.Move`, or None when lc0 reports no
        candidate moves for the position.
    """
    gamestate = GameState(fen=board.fen())
    input_planes = gamestate.as_input(backend)
    result = backend.evaluate(input_planes)[0]
    moves = gamestate.moves()
    policy_indices = gamestate.policy_indices()
    move_probs = np.array(result.p_softmax(*policy_indices))
    # argmax raises on an empty array; test for that case explicitly instead of
    # hiding every possible error behind a bare except.
    if move_probs.size == 0:
        return None
    best_move = moves[int(move_probs.argmax())]
    return chess.Move.from_uci(best_move)
def evaluate_position(fen, backend):
    """Return the q value lc0 assigns to the position described by `fen`.

    NOTE(review): the value is presumably from the side to move's point of
    view (the caller negates it after a move has been played) -- confirm
    against the lc0 bindings.
    """
    position = GameState(fen=fen)
    planes = position.as_input(backend)
    evaluation = backend.evaluate(planes)[0]
    return evaluation.q()
def reward_from_eval(before_eval, after_eval, scale=None):
    """Turn the change in evaluation caused by a move into a bounded reward.

    The reward is `diff / (scale + |diff|)` with `diff = after_eval -
    before_eval`: it has the sign of `diff` and its magnitude stays below 1
    for any positive `scale`.

    Args:
        before_eval: evaluation before the move.
        after_eval: evaluation after the move, from the same point of view.
        scale: softening constant; defaults to the module-level
            `move_reward_scale_factor`.
    """
    if scale is None:
        scale = move_reward_scale_factor
    diff = after_eval - before_eval
    return diff / (scale + abs(diff))
def backward_pass(loss):
    """Backpropagate `loss`, clip the gradients and take one optimizer step.

    When clipping is active the pre-clip gradient norm is recorded in the
    global `grad_norm_history`, which is trimmed to its `grad_clip_max_size`
    most recent entries. Gradients are cleared afterwards.
    """
    global grad_norm_history
    scaler.scale(loss).backward()
    clipping_active = auto_clip or grad_clip != 0.0
    if clipping_active:
        # Gradients must be unscaled before their norm is measured.
        scaler.unscale_(optimizer)
        # A threshold of 0 means none has been derived yet; use 0.1 meanwhile.
        max_norm = 0.1 if grad_clip == 0.0 else grad_clip
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
def _try_model_move(board, model_move):
    """Push `model_move` (SAN text) onto `board`; return True if it was playable."""
    if model_move is None:
        return False
    try:
        board.push(board.parse_san(model_move))
    except ValueError:
        # python-chess reports invalid, illegal and ambiguous SAN through
        # ValueError subclasses.
        return False
    if not board.is_valid():
        board.pop()
        return False
    return True


def play_game():
    """Play one training game from the global `opening_line` and learn from it.

    The model moves first after the opening and lc0 answers. Whenever the
    model's move cannot be played, lc0 plays in its place and the move is
    counted as illegal with a reward of -1. Playable moves are rewarded by the
    change in lc0's evaluation. The loss is backpropagated every `update_freq`
    moves and once more at the end, where it is scaled according to the game
    outcome.

    Returns:
        `(reward, result, illegal_moves, move_count)`: the average per-move
        reward divided by the final loss scale, the result string ("1-0",
        "0-1", "1/2-1/2", or "*" for an aborted game), the number of
        unplayable model moves and the number of moves played.
    """
    global top_k
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    board = chess.Board()
    total_loss = 0
    illegal_moves = 0
    move_count = 0
    moves_since_backward = 0
    tot_reward = 0
    fail = False
    top_k_lowered = False

    # Replay the opening on the board; move-number prefixes ("1.e4") are
    # stripped before parsing.
    opening_tokens = opening_line.split()
    for token in opening_tokens:
        board.push_san(token.split(".")[-1] if "." in token else token)
    if move_num_in_gamestate:
        game_state = opening_line.rstrip() + " "
    else:
        game_state = ' '.join('.' + t.split(".")[-1] if "." in t else t for t in opening_tokens)
        # Every move appended during play is followed by a space; give the
        # opening the same trailing separator (it was missing in this format,
        # gluing the first prompt marker onto the last opening move).
        if game_state:
            game_state += " "

    while not board.is_game_over():
        before_eval = evaluate_position(board.fen(), lc0_backend_evaluator)
        game_state += f"{board.fullmove_number if move_num_in_gamestate else ''}."
        model_move, logits = get_model_move(game_state, top_k)
        move_reward = -1  # stays in force when the model's move is unplayable

        if logits is not None and _try_model_move(board, model_move):
            game_state += f"{model_move} "
            # The opponent is now to move, so the evaluation is negated to
            # compare it with `before_eval`.
            after_eval = -evaluate_position(board.fen(), lc0_backend_evaluator)
            move_reward = reward_from_eval(before_eval, after_eval)
        else:
            illegal_moves += 1
            pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
            if pinch_hit_move is None:
                print("Failed game (lc0 couldn't find pinch-hit move)")
                fail = True
                tot_reward += move_reward
                move_count += 1
                break
            game_state += f"{board.san(pinch_hit_move)} "
            board.push(pinch_hit_move)
        tot_reward += move_reward

        if not board.is_game_over():
            black_move = get_lc0_move(board, lc0_backend_opponent)
            if black_move is None:
                print("Failed game (lc0 couldn't find black move)")
                fail = True
                move_count += 1
                break
            game_state += f"{board.san(black_move)} "
            board.push(black_move)

        logits_none = logits is None
        if not logits_none:
            total_loss += torch.sum(F.log_softmax(logits, dim=-1) * move_reward)
        del logits
        moves_since_backward += 1
        if move_count % update_freq == 0 and not board.is_game_over() and not logits_none:
            backward_pass(total_loss / moves_since_backward)
            total_loss = 0.0
            moves_since_backward = 0
        move_count += 1
        if move_count == top_k_adj_moves:
            # Sample from fewer candidates for the rest of the game.
            top_k = top_k - 1
            top_k_lowered = True
        if move_count >= max_moves:
            break
        if move_count % flush_every == 0:
            torch.cuda.empty_cache()

    # Undo the in-game reduction, and only if it actually happened: an aborted
    # game can reach the threshold count without having lowered top_k.
    if top_k_lowered:
        top_k = top_k + 1

    avg_reward = tot_reward / move_count if move_count else 0.0
    scale_factor = torch.tensor([1.0], device=device)
    if move_count >= max_moves:
        result = "1/2-1/2"
    elif fail:
        result = "*"
    else:
        result = board.result()

    if result == "0-1":
        if avg_reward < 0 and illegal_moves <= max_illegal_moves:
            scale_factor = torch.tensor([1.0 / decrease_factor], device=device)
    elif result == "1-0":
        wdf = decrease_factor / 2.0 if avg_reward <= 0 else 1.0 / decrease_factor
        if illegal_moves == 0:
            scale_factor = torch.tensor([wdf], device=device)
        elif illegal_moves <= max_illegal_moves:
            # A win that needed lc0's help is only credited as a draw.
            scale_factor = torch.tensor([(1 + wdf) / 2], device=device)
            result = "1/2-1/2"
        else:
            result = "0-1"

    # Flush whatever loss is still pending. After an abort right behind an
    # update (or before any move) there is nothing pending and `total_loss` is
    # a plain number; this used to divide by zero.
    if moves_since_backward > 0 and torch.is_tensor(total_loss):
        try:
            backward_pass(total_loss / moves_since_backward * scale_factor)
        except Exception:
            # Best effort: the game result stands even if the last update fails.
            print("Failed game (final backward pass, result not effected)")
    return avg_reward / scale_factor.item(), result, illegal_moves, move_count
def get_lr(it, *, peak_lr=None, floor_lr=None, warmup=None, decay_iters=None):
    """Return the learning rate for iteration `it`: linear warmup, then cosine decay.

    Args:
        it: iteration number.
        peak_lr: rate reached at the end of warmup (default `learning_rate`).
        floor_lr: rate at and after the end of the decay (default `min_lr`).
        warmup: number of warmup iterations (default `warmup_iters`).
        decay_iters: iteration at which the floor is reached
            (default `lr_decay_iters`).
    """
    peak_lr = learning_rate if peak_lr is None else peak_lr
    floor_lr = min_lr if floor_lr is None else floor_lr
    warmup = warmup_iters if warmup is None else warmup
    decay_iters = lr_decay_iters if decay_iters is None else decay_iters

    if it < warmup:
        return peak_lr * it / warmup
    # `>=` gives the same value as the cosine at the boundary and also covers
    # decay_iters <= warmup, where the decay span is empty and the ratio below
    # would divide by zero.
    if it >= decay_iters:
        return floor_lr
    decay_ratio = (it - warmup) / (decay_iters - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return floor_lr + coeff * (peak_lr - floor_lr)
if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)


def _build_checkpoint():
    """Collect the current training state into one dict that can be resumed from."""
    raw_model = model.module if ddp else model
    return {
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': mamba_config,
        'iter_num': iter_num,
        "games_seen": games_played,
        'config': config,
        'opening_line_index': opening_line_index,
        'grad_norm_history': grad_norm_history,
        'win_rate_window': win_rate_window,
        # Previously missing from the "best" checkpoints, so resuming from one
        # of them silently reset this window.
        'win_only_rate_window': win_only_rate_window,
        'best_wr': best_wr,
        'best_wor': best_wor,
    }


def _save_checkpoint(filename):
    """Write the current training state to `out_dir/filename`."""
    print(f"saving checkpoint to {out_dir}\n")
    torch.save(_build_checkpoint(), os.path.join(out_dir, filename))


os.makedirs(out_dir, exist_ok=True)

# Start the rates from the (possibly restored) windows so they are defined even
# if the very first game is aborted.
win_rate = sum(win_rate_window) / len(win_rate_window) if win_rate_window else 0.0
win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window) if win_only_rate_window else 0.0

# One game per opening line, until the openings are exhausted.
while True:
    t0 = time.time()
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    opening_line = opening_lines[opening_line_index]

    if iter_num > 0 and iter_num % save_interval == 0:
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            # Re-derive the clip threshold from the recorded gradient norms,
            # clamped to [min_grad_clip, max_grad_clip].
            percentile_norm = float(np.percentile(grad_norm_history, grad_clip_percentile))
            grad_clip = max(min(percentile_norm, max_grad_clip), min_grad_clip)
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")
        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/lr": lr,
                "etc/grad_clip": grad_clip,
                "etc/games_played": games_played,
            })
        _save_checkpoint('ckpt.pt')

    game_reward, result, illegal_moves, move_count = play_game()
    games_played += 1
    dt = time.time() - t0

    score = 0.5
    if result == "1-0":
        score = 1
    elif result == "0-1":
        score = 0

    if result != "*":  # aborted games do not count towards the win rates
        win_rate_window.append(score)
        win_rate_window = win_rate_window[-window_size:]
        win_rate = sum(win_rate_window) / len(win_rate_window)
        # int() truncates a draw (0.5) to 0, so this window counts wins only.
        win_only_rate_window.append(int(score))
        win_only_rate_window = win_only_rate_window[-window_size:]
        win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window)
        if win_rate > best_wr:
            best_wr = win_rate
            _save_checkpoint(f'ckpt_{games_played}g_wr{best_wr}.pt')
        elif win_only_rate > best_wor:
            best_wor = win_only_rate
            _save_checkpoint(f'ckpt_{games_played}g_wor{best_wor}.pt')
        best_wor = max(best_wor, win_only_rate)

    illegal_move_pct = illegal_moves / move_count if move_count else 0.0
    print(f"Game {games_played} ({iter_num}, {(iter_num / len(opening_lines)) * 100.0:.3f}%): Score {score}, Reward {game_reward:.4f}, Illegal moves {illegal_moves} ({illegal_move_pct:.3%}), Total moves {move_count}, Win rate {win_rate:.3f}, Win only rate {win_only_rate:.3f}, time {dt * 1000:.2f}ms")
    if wandb_log:
        # Compare element-wise on an array: `list == 0` is a single False, which
        # made the zero-gradient percentage always read 0.
        grad_norms = np.asarray(grad_norm_history, dtype=float)
        wandb.log({
            "etc/iter": iter_num,
            "etc/lr": lr,
            "etc/grad_norm_mean": float(grad_norms.mean()) if grad_norms.size else -1,
            "etc/grad_zero_pct": float(np.count_nonzero(grad_norms == 0)) / grad_norms.size if grad_norms.size else -1,
            "etc/games_played": games_played,
            "eval/game_reward": game_reward,
            "eval/illegal_move_pct": illegal_move_pct,
            "eval/move_ct": move_count,
            "eval/win_rate": win_rate,
            "eval/win_only_rate": win_only_rate,
        })

    iter_num += 1
    opening_line_index += 1
    if opening_line_index >= len(opening_lines):
        break

if ddp:
    destroy_process_group()