|
import math |
|
import os |
|
import time |
|
import pickle |
|
from contextlib import nullcontext |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.distributed import init_process_group, destroy_process_group |
|
from mamba_lm import MambaLM, MambaLMConfig |
|
import random |
|
import chess |
|
from lczero.backends import Weights, Backend, GameState |
|
|
|
|
|
# -----------------------------------------------------------------------------
# Run configuration.
# Every int/float/bool/str global that exists when `config_keys` is computed at
# the end of this section is captured into `config`, which is passed to wandb
# and stored inside every checkpoint.
# -----------------------------------------------------------------------------

# I/O
out_dir = 'out/play'  # directory checkpoints are loaded from and written to
save_interval = 50  # write the rolling checkpoint every N iterations
init_from = 'resume'  # 'resume' restores out_dir/ckpt.pt; any other value starts fresh

# Model
n_layer = 15
d_model = 256
dt_rank = 'auto'
d_state = 16
vocab_size = 28  # overwritten once the vocabulary has been built further down
move_num_in_gamestate = False  # True: prompts carry move numbers and the vocab comes from meta.pkl

# wandb logging.
# The script previously assigned wandb_project / wandb_run_name twice; only the
# later pair ('mamba-rl' / 'mamba_run') ever took effect, so it is the one kept.
wandb_log = True
wandb_project = 'mamba-rl'
wandb_run_name = 'mamba_run'

# Openings: one game is played per line, the first line of the file is skipped.
with open("openings.csv", "r") as file:
    lines = file.readlines()[1:]
opening_lines = lines

# Optimizer / learning-rate schedule
learning_rate = 1e-7  # peak learning rate reached at the end of warmup
min_lr = 1e-8  # floor of the cosine decay
warmup_iters = 600
lr_decay_iters = len(opening_lines)  # decay spans one pass over the openings
weight_decay = 1e-2
beta1 = 0.905
beta2 = 0.965

# Gradient clipping. With auto_clip the start-up code overrides grad_clip
# (checkpoint value on resume, 0 otherwise) and the training loop periodically
# re-estimates it as a percentile of recent gradient norms, clamped to
# [min_grad_clip, max_grad_clip].
grad_clip = 0.5
min_grad_clip = 1e-3
max_grad_clip = 0.45
auto_clip = True
grad_clip_start_size = 150  # minimum history length before auto-adjusting
grad_clip_max_size = 600  # number of recent gradient norms kept
grad_clip_percentile = 9

# Game play / reward shaping
top_k = 2  # sample each token among the k highest logits
top_k_adj_moves = 40  # after this many moves top_k is lowered by one for the rest of the game
max_illegal_moves = 8  # threshold used when scaling the final loss and relabelling wins
max_moves = 87  # games reaching this many moves are scored as draws
update_freq = 3  # run a backward pass every N moves
flush_every = 1  # call torch.cuda.empty_cache() every N moves
move_reward_scale_factor = 4.0  # denominator offset in reward_from_eval
decrease_factor = 0.75  # basis of the end-of-game loss scale factors
window_size = 300  # number of recent games in the rolling win-rate windows

# Distributed
backend = 'nccl'  # presumably the torch.distributed backend name - confirm before relying on DDP

# System
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Only query bf16 support when CUDA is actually present: on CPU-only setups the
# query itself can raise in some PyTorch versions.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float32'
compile = False  # NOTE: shadows the builtin; the name is kept because it is a config key

# Snapshot of all simple-valued settings defined above.
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys}
|
|
|
|
|
# lc0 network used to play Black and to supply "pinch-hit" moves when the
# model fails to produce a legal move.
lc0_weights_opponent = Weights("./lc0/build/release/11258-32x4-se.pb.gz")
lc0_backend_opponent = Backend(weights=lc0_weights_opponent)

# lc0 network intended for position evaluation (used for per-move rewards).
# NOTE(review): the evaluator weights are loaded but no Backend is built from
# them - the evaluator backend below is an alias of the opponent backend, so
# rewards are currently computed with the 32x4 network. Confirm whether this
# is deliberate (e.g. to save memory) or should be
# Backend(weights=lc0_weights_evaluator).
lc0_weights_evaluator = Weights("./lc0/build/release/11258-48x5-se.pb.gz")
lc0_backend_evaluator = lc0_backend_opponent
|
|
|
|
|
# Build the character-level vocabulary and the encode/decode helpers.
if not move_num_in_gamestate:
    # Built-in vocabulary for prompts without move numbers. '-' is not part of
    # it: castling is stored as "OO" / "OOO" and the hyphens are restored when
    # decoding.
    stoi = {' ': 0, '.': 1, 'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6, 'f': 7, 'g': 8, 'h': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, 'B': 18, 'N': 19, 'R': 20, 'Q': 21, 'K': 22, 'O': 23, 'x': 24, '+': 25, '#': 26, '=': 27}
    itos = {index: char for char, index in stoi.items()}
    # Sanity check: the two tables must be exact inverses of each other.
    assert all(itos[index] == char for char, index in stoi.items())
    vocab_size = len(stoi)
    print(f"Vocab size {vocab_size}")

    def encode(s):
        """Map a string to token ids, dropping '-' characters first."""
        return [stoi[c] for c in s.replace('-', '')]

    def decode(l):
        """Map token ids back to text, re-inserting the castling hyphens."""
        text = "".join([itos[i] for i in l])
        return text.replace("OOO", "O-O-O").replace("OO", "O-O")
else:
    # Vocabulary produced by the dataset preparation step.
    # NOTE: pickle.load executes arbitrary code from the file; only use a
    # meta.pkl you generated yourself.
    meta_path = os.path.join('data', 'chess', 'meta.pkl')
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)
    stoi, itos = meta["stoi"], meta["itos"]
    vocab_size = meta['vocab_size']

    def encode(s):
        """Map a string to a list of token ids."""
        return [stoi[c] for c in s]

    def decode(l):
        """Map a list of token ids back to a string."""
        return "".join([itos[i] for i in l])
|
|
|
|
|
|
|
|
|
# Detect a distributed launch (RANK is set in the environment by the launcher).
# The process group has to exist and the local rank has to be known before the
# model can be wrapped in DDP; previously `ddp_local_rank` was never defined
# and the process group was destroyed at exit without ever being created.
ddp = int(os.environ.get('RANK', -1)) != -1
if ddp:
    init_process_group(backend=backend)
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)

# Model
mamba_config = MambaLMConfig(
    d_model=d_model,
    n_layers=n_layer,
    dt_rank=dt_rank,
    d_state=d_state,
    vocab_size=vocab_size
)

model = MambaLM(mamba_config)
model.to(device)

# Optimizer and gradient scaler (the scaler is only active for float16).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
scaler = torch.cuda.amp.GradScaler(enabled=dtype == 'float16')

# Optional compilation
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

# Optional DDP wrapping
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
|
|
|
# Rolling result windows (replaced by the checkpointed ones when resuming).
win_rate_window = []
win_only_rate_window = []

if init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
    # The checkpoint is written by this script and contains pickled Python
    # objects (model config, run config), so it cannot be read in
    # weights-only mode. Only resume from checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    mamba_config = checkpoint['model_args']
    state_dict = checkpoint['model']

    # Checkpoints saved from a compiled model carry an '_orig_mod.' prefix on
    # every key; strip it so the keys match the plain module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    # Load into the underlying module: a DDP wrapper expects 'module.'-prefixed
    # keys and a compiled wrapper '_orig_mod.'-prefixed ones, neither of which
    # the (un-prefixed) checkpoint provides.
    load_target = model.module if ddp else model
    load_target = getattr(load_target, '_orig_mod', load_target)
    load_target.load_state_dict(state_dict)
    optimizer.load_state_dict(checkpoint['optimizer'])

    iter_num = checkpoint['iter_num']
    games_played = checkpoint['games_seen']
    # Older checkpoints may lack the fields below, hence the defaults.
    opening_line_index = checkpoint.get('opening_line_index', 0)
    win_rate_window = checkpoint.get('win_rate_window', [])
    win_only_rate_window = checkpoint.get('win_only_rate_window', [])
    best_wr = checkpoint.get('best_wr', 0.0)
    best_wor = checkpoint.get('best_wor', 0.0)
    if auto_clip:
        # Continue from the automatically tuned clip value, not the configured one.
        grad_clip = checkpoint['config']['grad_clip']
        config['grad_clip'] = grad_clip
        grad_norm_history = checkpoint.get('grad_norm_history', [])
    else:
        grad_norm_history = []
else:
    best_wr = 0.0
    best_wor = 0.0
    grad_norm_history = []
    games_played = 0
    iter_num = 0
    opening_line_index = 0
    if auto_clip:
        # 0 means "not tuned yet"; backward_pass falls back to a fixed
        # max-norm until the training loop has estimated a value.
        grad_clip = 0
        config['grad_clip'] = 0
|
|
|
|
|
def get_model_move(game_state, top_k):
    """Sample the model's next move for the given game transcript.

    Up to 8 tokens are sampled one at a time. Sampling stops as soon as a
    delimiter token is drawn after at least one token has been accepted (the
    first sampled token is always accepted).

    The logits are produced with the model in train mode and without
    ``torch.no_grad`` so that the caller can build a loss from them.

    Args:
        game_state: Transcript of the game so far, used as the prompt.
        top_k: If not None and smaller than the vocabulary, sample only among
            the ``top_k`` highest logits.

    Returns:
        Tuple ``(move, logits)``. ``move`` is the first whitespace-separated
        token of the generated text (before any ';'), or None if nothing usable
        was generated. ``logits`` is the stack of per-step logits of the
        accepted tokens (restricted to the top-k entries when top-k sampling
        was applied), or None if no token was accepted.
    """
    model.train()
    encoded_prompt = encode(game_state)
    input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=device)

    # Delimiters are looked up in the active vocabulary rather than hard-coded:
    # a fixed id of 4 is 'c' in the built-in vocabulary, which cut every move
    # containing a 'c' after its first character (e.g. "Nc3" -> "N").
    stop_ids = {stoi[ch] for ch in (' ', '.') if ch in stoi}

    logits_list = []
    for step in range(8):
        logits = model(input_ids)[0, -1, :]

        if top_k is not None and top_k < logits.size(-1):
            # Keep only the k best candidates; `indices` maps back to token ids.
            logits, indices = torch.topk(logits, top_k)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = indices[torch.multinomial(probs, 1)]
        else:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

        # A delimiter ends the move, except on the very first step.
        if step > 0 and next_token_id.item() in stop_ids:
            break

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)
        logits_list.append(logits)
        del logits, probs

    model_response = decode(input_ids.squeeze(0).tolist())
    try:
        move = model_response[len(game_state):].split(";")[0].split()[0]
    except IndexError:
        # Only whitespace (or nothing) was generated.
        move = None

    return move, torch.stack(logits_list) if logits_list else None
|
|
|
def get_lc0_move(board, backend):
    """Return lc0's highest-policy move for the current position.

    Args:
        board: python-chess board; only its FEN is passed to lc0.
        backend: lc0 backend used to evaluate the position.

    Returns:
        The move with the largest policy probability as a ``chess.Move``, or
        None when lc0 reports no moves for the position.
    """
    gamestate = GameState(fen=board.fen())
    moves = gamestate.moves()
    if not moves:
        # Nothing to choose from; checked explicitly instead of relying on
        # argmax raising on an empty array.
        return None

    input_planes = gamestate.as_input(backend)
    result = backend.evaluate(input_planes)[0]
    policy_indices = gamestate.policy_indices()
    move_probs = np.array(result.p_softmax(*policy_indices))
    if move_probs.size == 0:
        return None

    best_move = moves[int(move_probs.argmax())]
    return chess.Move.from_uci(best_move)
|
|
|
def evaluate_position(fen, backend):
    """Return lc0's ``q()`` value for the position given as a FEN string."""
    position = GameState(fen=fen)
    input_planes = position.as_input(backend)
    evaluations = backend.evaluate(input_planes)
    return evaluations[0].q()
|
|
|
def reward_from_eval(before_eval, after_eval):
    """Turn an evaluation change into a reward.

    The change ``after_eval - before_eval`` is divided by
    ``move_reward_scale_factor + |change|``, so the result keeps the sign of
    the change and its magnitude stays below 1.
    """
    delta = after_eval - before_eval
    damping = move_reward_scale_factor + abs(delta)
    return delta / damping
|
|
|
def backward_pass(loss):
    """Backpropagate ``loss``, clip gradients and take one optimizer step.

    When clipping is active (a non-zero ``grad_clip`` or ``auto_clip``), the
    pre-clip gradient norm is appended to the global ``grad_norm_history``,
    which is trimmed to its last ``grad_clip_max_size`` entries. Gradients are
    cleared after the step.
    """
    global grad_norm_history

    scaler.scale(loss).backward()

    clipping_active = auto_clip or grad_clip != 0.0
    if clipping_active:
        # Gradients must be unscaled before their norm is measured / clipped.
        scaler.unscale_(optimizer)
        # A clip value of 0 means "not tuned yet": use a fixed fallback.
        max_norm = 0.1 if grad_clip == 0.0 else grad_clip
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        grad_norm_history.append(total_norm.item())
        grad_norm_history = grad_norm_history[-grad_clip_max_size:]

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
|
|
|
def _play_pinch_hit_move(board, game_state):
    """Let lc0 play the model's move after the model failed to produce a legal one.

    Pushes the move onto ``board`` and returns ``game_state`` extended by its
    SAN, or returns None (board untouched) when lc0 has no move.
    """
    pinch_hit_move = get_lc0_move(board, lc0_backend_opponent)
    if pinch_hit_move is None:
        print("Failed game (lc0 couldn't find pinch-hit move)")
        return None
    game_state += f"{board.san(pinch_hit_move)} "
    board.push(pinch_hit_move)
    return game_state


def play_game():
    """Play one training game from the current global ``opening_line``.

    The model moves first after the opening, lc0 replies. Each model move is
    rewarded by the change in lc0's evaluation (or -1 when lc0 had to supply
    the move), the per-move losses are accumulated, and a backward pass is run
    every ``update_freq`` moves plus once at the end of the game with a scale
    factor that depends on the result and the number of illegal moves.

    Returns:
        Tuple ``(reward, result, illegal_moves, move_count)`` where ``reward``
        is the average move reward divided by the final scale factor and
        ``result`` is "1-0", "0-1", "1/2-1/2" or "*" (failed game). A win with
        illegal moves is relabelled as a draw, or as a loss when there were
        more than ``max_illegal_moves`` of them.
    """
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    board = chess.Board()
    total_loss = 0
    illegal_moves = 0
    move_count = 0
    moves_since_backward = 0
    tot_reward = 0
    fail = False
    # Per-game copy of top_k. The global used to be decremented during the game
    # and incremented afterwards, which let it drift upwards whenever a game
    # failed exactly at top_k_adj_moves (restored without having been lowered).
    current_top_k = top_k

    # Replay the opening on the board (move-number prefixes such as "1." are stripped).
    tokens = [m.split(".")[-1] if "." in m else m for m in opening_line.split()]
    for san in tokens:
        board.push_san(san)
    if move_num_in_gamestate:
        game_state = opening_line.rstrip() + " "
    else:
        game_state = ' '.join(['.' + m.split(".")[-1] if "." in m else m for m in opening_line.split()])

    while not board.is_game_over():
        before_eval = evaluate_position(board.fen(), lc0_backend_evaluator)
        game_state += f"{board.fullmove_number if move_num_in_gamestate else ''}."
        model_move, logits = get_model_move(game_state, current_top_k)
        move_reward = -1  # penalty kept unless the model's own move is played

        # Try to play the model's move.
        model_move_played = False
        if model_move is not None and logits is not None:
            try:
                board.push(board.parse_san(model_move))
            except Exception:
                # Unparseable, ambiguous or illegal SAN.
                model_move_played = False
            else:
                if board.is_valid():
                    model_move_played = True
                else:
                    board.pop()

        if model_move_played:
            game_state += f"{model_move} "
            # After the push it is the opponent's turn, so the evaluation is
            # negated to get it from the model's side again.
            after_eval = -evaluate_position(board.fen(), lc0_backend_evaluator)
            move_reward = reward_from_eval(before_eval, after_eval)
        else:
            illegal_moves += 1
            new_state = _play_pinch_hit_move(board, game_state)
            if new_state is None:
                fail = True
                tot_reward += move_reward
                move_count += 1
                break
            game_state = new_state

        tot_reward += move_reward

        # Opponent reply.
        if not board.is_game_over():
            black_move = get_lc0_move(board, lc0_backend_opponent)
            if black_move is None:
                print("Failed game (lc0 couldn't find black move)")
                fail = True
                move_count += 1
                break
            game_state += f"{board.san(black_move)} "
            board.push(black_move)

        # Accumulate this move's loss term.
        logits_none = logits is None
        if not logits_none:
            total_loss += torch.sum(torch.nn.functional.log_softmax(logits, dim=-1) * move_reward)
        del logits
        moves_since_backward += 1

        # Intermediate update; the final move is handled after the loop so the
        # game result can scale its loss.
        if move_count % update_freq == 0 and not board.is_game_over() and not logits_none:
            backward_pass(total_loss / moves_since_backward)
            total_loss = 0.0
            moves_since_backward = 0

        move_count += 1
        if move_count == top_k_adj_moves:
            # Sample more greedily for the rest of the game (never below 1).
            current_top_k = max(1, current_top_k - 1)
        if move_count >= max_moves:
            break
        if move_count % flush_every == 0:
            torch.cuda.empty_cache()

    # No move is counted when the opening itself already ends the game.
    avg_reward = tot_reward / move_count if move_count else 0.0

    scale_factor = torch.tensor([1.0], device=device)
    if move_count >= max_moves:
        result = "1/2-1/2"
    elif fail:
        result = "*"
    else:
        result = board.result()

    if result == "0-1":
        if avg_reward < 0 and illegal_moves <= max_illegal_moves:
            scale_factor = torch.tensor([1.0 / decrease_factor], device=device)
    elif result == "1-0":
        wdf = decrease_factor / 2.0 if avg_reward <= 0 else 1.0 / decrease_factor
        if illegal_moves == 0:
            scale_factor = torch.tensor([wdf], device=device)
        elif illegal_moves <= max_illegal_moves:
            # lc0 helped: soften the scaling and only count the game as a draw.
            scale_factor = torch.tensor([(1 + wdf) / 2], device=device)
            result = "1/2-1/2"
        else:
            # Too much help from lc0 to count as anything but a loss.
            result = "0-1"

    # Final update. total_loss is a plain number (and moves_since_backward may
    # be 0) when nothing was accumulated since the last backward pass, e.g.
    # when the game broke off right after an intermediate update.
    if torch.is_tensor(total_loss) and moves_since_backward > 0:
        try:
            backward_pass(total_loss / moves_since_backward * scale_factor)
        except Exception:
            print("Failed game (final backward pass, result not effected)")
        total_loss = 0.0

    return avg_reward / scale_factor.item(), result, illegal_moves, move_count
|
|
|
|
|
def get_lr(it):
    """Learning rate for iteration ``it``: linear warmup, then cosine decay.

    * ``it < warmup_iters``: ramps linearly from 0 up to ``learning_rate``.
    * ``it > lr_decay_iters``: constant ``min_lr``.
    * otherwise: cosine interpolation from ``learning_rate`` down to ``min_lr``.
    """
    if it < warmup_iters:
        lr = learning_rate * it / warmup_iters
    elif it > lr_decay_iters:
        lr = min_lr
    else:
        progress = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        assert 0 <= progress <= 1
        # Goes from 1 at the start of the decay to 0 at its end.
        cosine_weight = 0.5 * (1.0 + math.cos(math.pi * progress))
        lr = min_lr + cosine_weight * (learning_rate - min_lr)
    return lr
|
|
|
|
|
def _build_checkpoint():
    """Collect the current training state into a checkpoint dictionary."""
    raw_model = model.module if ddp else model
    return {
        'model': raw_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'model_args': mamba_config,
        'iter_num': iter_num,
        "games_seen": games_played,
        'config': config,
        'opening_line_index': opening_line_index,
        'grad_norm_history': grad_norm_history,
        'win_rate_window': win_rate_window,
        'win_only_rate_window': win_only_rate_window,
        'best_wr': best_wr,
        'best_wor': best_wor
    }


if wandb_log:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

os.makedirs(out_dir, exist_ok=True)

# Start from the rates implied by the restored windows so that logging works
# even when the first game of this run fails (result "*") and therefore does
# not update them.
win_rate = sum(win_rate_window) / len(win_rate_window) if win_rate_window else 0.0
win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window) if win_only_rate_window else 0.0

# One game per opening line; the condition also covers a resumed index that is
# already past the end of the openings file.
while opening_line_index < len(opening_lines):
    t0 = time.time()
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    opening_line = opening_lines[opening_line_index]

    # Periodic housekeeping: retune the clip value, log, write the rolling checkpoint.
    if iter_num > 0 and iter_num % save_interval == 0:
        if auto_clip and len(grad_norm_history) >= grad_clip_start_size:
            # Percentile of recent gradient norms, clamped to the allowed range
            # (converted to a plain float so the stored config stays simple).
            grad_clip = float(max(min(np.percentile(grad_norm_history, grad_clip_percentile), max_grad_clip), min_grad_clip))
            config['grad_clip'] = grad_clip
            print(f"Auto adjusted grad_clip to {grad_clip}")

        if wandb_log:
            wandb.log({
                "etc/iter": iter_num,
                "etc/lr": lr,
                "etc/grad_clip": grad_clip,
                "etc/games_played": games_played,
            })

        print(f"saving checkpoint to {out_dir}\n")
        torch.save(_build_checkpoint(), os.path.join(out_dir, 'ckpt.pt'))

    game_reward, result, illegal_moves, move_count = play_game()
    games_played += 1

    t1 = time.time()
    dt = t1 - t0
    t0 = t1

    score = 0.5
    if result == "1-0":
        score = 1
    elif result == "0-1":
        score = 0

    # Failed games ("*") do not count towards the win-rate statistics.
    if result != "*":
        win_rate_window.append(score)
        win_rate_window = win_rate_window[-window_size:]
        win_rate = sum(win_rate_window) / len(win_rate_window)
        # int() maps a draw (0.5) to 0, so this window counts wins only.
        win_only_rate_window.append(int(score))
        win_only_rate_window = win_only_rate_window[-window_size:]
        win_only_rate = float(sum(win_only_rate_window)) / len(win_only_rate_window)

        # Keep a separate checkpoint whenever a rolling rate reaches a new best.
        if win_rate > best_wr:
            best_wr = win_rate
            print(f"saving checkpoint to {out_dir}\n")
            torch.save(_build_checkpoint(), os.path.join(out_dir, f'ckpt_{games_played}g_wr{best_wr}.pt'))
        elif win_only_rate > best_wor:
            best_wor = win_only_rate
            print(f"saving checkpoint to {out_dir}\n")
            torch.save(_build_checkpoint(), os.path.join(out_dir, f'ckpt_{games_played}g_wor{best_wor}.pt'))
        best_wor = max(best_wor, win_only_rate)

    # A game can end without a single counted move (opening already decisive).
    illegal_move_pct = illegal_moves / move_count if move_count else 0.0

    print(f"Game {games_played} ({iter_num}, {(iter_num / len(opening_lines)) * 100.0:.3f}%): Score {score}, Reward {game_reward:.4f}, Illegal moves {illegal_moves} ({illegal_move_pct:.3%}), Total moves {move_count}, Win rate {win_rate:.3f}, Win only rate {win_only_rate:.3f}, time {dt * 1000:.2f}ms")

    if wandb_log:
        wandb.log({
            "etc/iter": iter_num,
            "etc/lr": lr,
            "etc/grad_norm_mean": np.mean(grad_norm_history) if grad_norm_history else -1,
            # The history is a Python list; it must be converted to an array
            # for the element-wise comparison (list == 0 is just False).
            "etc/grad_zero_pct": float(np.count_nonzero(np.asarray(grad_norm_history) == 0)) / len(grad_norm_history) if grad_norm_history else -1,
            "etc/games_played": games_played,
            "eval/game_reward": game_reward,
            "eval/illegal_move_pct": illegal_move_pct,
            "eval/move_ct": move_count,
            "eval/win_rate": win_rate,
            "eval/win_only_rate": win_only_rate,
        })

    iter_num += 1
    opening_line_index += 1

if ddp:
    destroy_process_group()
|
|