Upload model_aggressive.py
model_aggressive.py (ADDED, +769 -0)
@@ -0,0 +1,769 @@
| 1 |
+
# NOTE FOR COLAB USERS: Run in a separate cell first:
|
| 2 |
+
# !pip -q install chess numpy torch matplotlib pandas
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Aggressive GRPO Chess Agent β T4/Colab Optimized
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os, sys, csv, time, math, shutil, argparse, random
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import matplotlib
|
| 12 |
+
matplotlib.use('Agg')
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import chess
|
| 17 |
+
except ImportError:
|
| 18 |
+
os.system("pip install -q chess")
|
| 19 |
+
import chess
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
# ββ Hardware flags βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
torch.backends.cudnn.benchmark = True
|
| 27 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 28 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 29 |
+
if hasattr(torch, 'set_float32_matmul_precision'):
|
| 30 |
+
torch.set_float32_matmul_precision('high')
|
| 31 |
+
|
| 32 |
+
# ββ Constants ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
PIECE_VAL = {
|
| 34 |
+
chess.PAWN: 1.0, chess.KNIGHT: 3.0, chess.BISHOP: 3.2,
|
| 35 |
+
chess.ROOK: 5.0, chess.QUEEN: 9.0, chess.KING: 0.0,
|
| 36 |
+
}
|
| 37 |
+
RANDOM_BASELINE_ELO = 800 # estimated ELO of uniform-random player
|
| 38 |
+
|
| 39 |
+
CONFIG = {
|
| 40 |
+
"num_envs": 256,
|
| 41 |
+
"grpo_group_size": 8, # G envs per group, all start from same opening position
|
| 42 |
+
"ppo_epochs": 3,
|
| 43 |
+
"mini_batch_size": 4096,
|
| 44 |
+
"learning_rate": 2e-4,
|
| 45 |
+
"weight_decay": 1e-4,
|
| 46 |
+
"gamma": 0.98, # lower β discount future more β prefer fast wins
|
| 47 |
+
"clip_epsilon": 0.15,
|
| 48 |
+
"entropy_coef": 0.02, # low β exploit aggressive lines
|
| 49 |
+
"value_coef": 0.5,
|
| 50 |
+
"max_steps": 100,
|
| 51 |
+
"opening_max_moves": 10, # randomize opening for GRPO diversity
|
| 52 |
+
"checkpoint_dir": "./checkpoints",
|
| 53 |
+
"save_interval": 50,
|
| 54 |
+
"log_interval": 1,
|
| 55 |
+
"elo_eval_interval": 100, # evaluate ELO every N iterations
|
| 56 |
+
"elo_eval_games": 32,
|
| 57 |
+
"max_runtime_hours": 4.5, # auto-save + download before Colab kills session
|
| 58 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 59 |
+
"seed": 42,
|
| 60 |
+
}
|
| 61 |
+
|
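# With these defaults, 256 envs split into 256 / 8 = 32 groups of G = 8; every
# member of a group replays the same random opening, so within-group return
# spread reflects the policy's own choices rather than opening luck.
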
# ── Action Space ───────────────────────────────────────────────────────────────
class ActionMapper:
    __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']
    def __init__(self):
        self.move_to_idx: dict[str, int] = {}
        self.idx_to_move: list[str] = []
        idx = 0
        for f in range(64):
            for t in range(64):
                if f == t: continue
                uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t]
                self.move_to_idx[uci] = idx
                self.idx_to_move.append(uci)
                idx += 1
                if chess.square_rank(f) in (1, 6) and \
                   abs(chess.square_file(f) - chess.square_file(t)) <= 1:
                    for promo in "nbrq":
                        puci = uci + promo
                        self.move_to_idx[puci] = idx
                        self.idx_to_move.append(puci)
                        idx += 1
        self.num_actions = idx

ACTION_MAPPER = ActionMapper()

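# Accounting note: 64*63 = 4032 plain from→to pairs, plus four promotion
# suffixes for every pair whose from-square sits on rank 2 or 7 with a file
# shift of at most one. The generator never checks the destination rank, so
# some promotion indices can never be legal; the per-position legality masks
# below keep those dead entries harmless.
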
# ── Board Encoding ─────────────────────────────────────────────────────────────
def populate_states_fast(envs: list, active_mask: np.ndarray,
                         bbs_np: np.ndarray, meta_np: np.ndarray) -> None:
    """Fill bbs_np [B,12] uint64 and meta_np [B,3] float32 for active envs."""
    for b in range(len(envs)):
        if not active_mask[b]: continue
        env = envs[b]
        w = env.occupied_co[chess.WHITE]
        bc = env.occupied_co[chess.BLACK]
        bbs_np[b, 0] = env.pawns & w;    bbs_np[b, 1] = env.knights & w
        bbs_np[b, 2] = env.bishops & w;  bbs_np[b, 3] = env.rooks & w
        bbs_np[b, 4] = env.queens & w;   bbs_np[b, 5] = env.kings & w
        bbs_np[b, 6] = env.pawns & bc;   bbs_np[b, 7] = env.knights & bc
        bbs_np[b, 8] = env.bishops & bc; bbs_np[b, 9] = env.rooks & bc
        bbs_np[b, 10] = env.queens & bc; bbs_np[b, 11] = env.kings & bc
        meta_np[b, 0] = 1.0 if env.turn else -1.0
        # castling_rights is a *bitboard* of rook squares (it can exceed 2**63),
        # not a 4-bit flag set, so compress it to 4 bits before normalizing
        cr = env.castling_rights
        code = (((cr >> 7) & 1)             # White kingside  (h1)
                | (((cr >> 0) & 1) << 1)    # White queenside (a1)
                | (((cr >> 63) & 1) << 2)   # Black kingside  (h8)
                | (((cr >> 56) & 1) << 3))  # Black queenside (a8)
        meta_np[b, 1] = code / 15.0  # [0,1]
        meta_np[b, 2] = 1.0 if env.ep_square is not None else 0.0

def get_legal_masks(envs: list, active_mask: np.ndarray):
    masks = np.zeros((len(envs), ACTION_MAPPER.num_actions), dtype=np.bool_)
    moves_list = [None] * len(envs)
    for b in range(len(envs)):
        if not active_mask[b]: continue
        legal = list(envs[b].legal_moves)
        moves_list[b] = legal
        for m in legal:
            masks[b, ACTION_MAPPER.move_to_idx[m.uci()]] = True
    return masks, moves_list

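# Illustration only (never called by training): a single bitboard unpacks to
# an 8x8 plane with the same shift-and-mask trick the rollout runs on the GPU.
# python-chess numbers squares a1=0 .. h8=63, so row r of the plane is rank r+1.
def _demo_unpack_bitboard(bb: int) -> np.ndarray:
    bits = (np.uint64(bb) >> np.arange(64, dtype=np.uint64)) & np.uint64(1)
    return bits.reshape(8, 8)
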
# ── Neural Network ─────────────────────────────────────────────────────────────
class ChessNet(nn.Module):
    def __init__(self, res_blocks: int = 8, channels: int = 128):
        super().__init__()
        self.conv_in = nn.Conv2d(14, channels, 3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
            ) for _ in range(res_blocks)
        ])
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Flatten(),
            nn.Linear(32 * 64, ACTION_MAPPER.num_actions),
        )
        # No Tanh → shaped rewards exceed [-1,1]; unbounded linear output
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Flatten(),
            nn.Linear(32 * 64, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
        for blk in self.res_blocks:
            x = F.relu(x + blk(x), inplace=True)
        return self.policy_head(x), self.value_head(x)

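# Shape sanity check (illustrative, not executed here):
#   net = ChessNet()
#   logits, value = net(torch.zeros(2, 14, 8, 8))
#   # logits: (2, ACTION_MAPPER.num_actions), value: (2, 1)
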
# ── ELO Tracker ────────────────────────────────────────────────────────────────
class ELOTracker:
    def __init__(self, initial_elo: float = 1200.0, K: float = 32.0):
        self.elo = initial_elo
        self.K = K

    def expected(self, opp_elo: float) -> float:
        return 1.0 / (1.0 + 10.0 ** ((opp_elo - self.elo) / 400.0))

    def update(self, score: float, opp_elo: float) -> None:
        self.elo += self.K * (score - self.expected(opp_elo))

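# Worked example: at 1200 vs the 800-rated random baseline the expected score
# is 1 / (1 + 10**((800-1200)/400)) ≈ 0.909, so a win gains only
# 32 * (1 - 0.909) ≈ +2.9 points while a draw costs 32 * (0.5 - 0.909) ≈ -13.1.
# Beating a much weaker opponent pays very little.
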
# ── Opening Position Generator ─────────────────────────────────────────────────
def get_opening_position(max_moves: int = 10) -> chess.Board:
    """Play 0..max_moves random half-moves from start for GRPO diversity."""
    board = chess.Board()
    for _ in range(random.randint(0, max_moves)):
        if board.is_game_over(): break
        board.push(random.choice(list(board.legal_moves)))
    return chess.Board(board.fen())  # detached copy

# ── Auto-download ──────────────────────────────────────────────────────────────
def auto_download(checkpoint_dir: str) -> None:
    """Sync to Google Drive if mounted, else trigger browser downloads."""
    try:
        from google.colab import files as _cf
        drive_dst = '/content/drive/MyDrive/chess_agent'
        if os.path.exists('/content/drive/MyDrive'):
            os.makedirs(drive_dst, exist_ok=True)
            shutil.copytree(checkpoint_dir, drive_dst, dirs_exist_ok=True)
            print(f"[AutoSave] Synced → {drive_dst}")
        else:
            for fname in ['best.pt', 'latest.pt', 'training_log.csv',
                          'elo_log.csv', 'training_performance.png']:
                fpath = os.path.join(checkpoint_dir, fname)
                if os.path.exists(fpath):
                    _cf.download(fpath)
                    print(f"[AutoSave] Downloaded {fname}")
    except Exception as e:
        print(f"[AutoSave] {e}")

# ── GRPO Trainer ───────────────────────────────────────────────────────────────
class GRPOTrainer:

    def __init__(self):
        self.device = CONFIG["device"]

        _model = ChessNet(res_blocks=8, channels=128)
        _model = _model.to(self.device).to(memory_format=torch.channels_last)
        try:
            print("Compiling model (reduce-overhead)…")
            self.model = torch.compile(_model, mode="reduce-overhead")
        except Exception:
            self.model = _model

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=CONFIG["learning_rate"],
            weight_decay=CONFIG["weight_decay"],
            fused=torch.cuda.is_available(),
        )
        self.scaler = torch.amp.GradScaler('cuda')
        self.start_iter = 0
        self.best_win_rate = 0.0
        self.elo_tracker = ELOTracker()

        # Shared shift tensor for bit-unpacking (avoid repeated allocation)
        self.shifts = torch.arange(64, dtype=torch.int64,
                                   device=self.device).view(1, 1, 64)

        os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)
        self.log_file = os.path.join(CONFIG["checkpoint_dir"], "training_log.csv")
        self.elo_log_file = os.path.join(CONFIG["checkpoint_dir"], "elo_log.csv")

        if not os.path.exists(self.log_file):
            with open(self.log_file, "w", newline="") as f:
                csv.writer(f).writerow([
                    "iteration", "p_loss", "v_loss", "v_mean", "fps",
                    "win_rate", "draw_rate", "check_rate", "capture_rate", "avg_game_len",
                ])
        if not os.path.exists(self.elo_log_file):
            with open(self.elo_log_file, "w", newline="") as f:
                csv.writer(f).writerow(
                    ["iteration", "elo", "eval_wins", "eval_draws", "eval_losses"])

        self._init_checkpointing()

    # ── Checkpointing ──────────────────────────────────────────────────────────
    def _init_checkpointing(self) -> None:
        latest = os.path.join(CONFIG["checkpoint_dir"], "latest.pt")
        if not os.path.exists(latest):
            return
        try:
            ckpt = torch.load(latest, map_location=self.device, weights_only=False)
            sd = ckpt['model_state_dict']
            # Handle compiled (_orig_mod. prefix) vs uncompiled state dicts
            loaded = False
            for attempt in [
                sd,
                {k.replace('_orig_mod.', ''): v for k, v in sd.items()},
                {'_orig_mod.' + k: v for k, v in sd.items()},
            ]:
                try:
                    self.model.load_state_dict(attempt); loaded = True; break
                except RuntimeError:
                    continue
            if not loaded:
                raise RuntimeError("All state dict key variants failed.")
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.scaler.load_state_dict(ckpt['scaler_state_dict'])
            self.start_iter = ckpt.get('iteration', 0) + 1
            self.elo_tracker.elo = ckpt.get('elo', 1200.0)
            self.best_win_rate = ckpt.get('best_win_rate', 0.0)
            print(f"Resumed from iter {self.start_iter} | "
                  f"ELO {self.elo_tracker.elo:.0f} | best_win {self.best_win_rate:.3f}")
        except Exception as e:
            print(f"Checkpoint load failed ({e}). Starting fresh.")

    def save_checkpoint(self, iteration: int, is_best: bool = False) -> None:
        ckpt = {
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'elo': self.elo_tracker.elo,
            'best_win_rate': self.best_win_rate,
            'config': CONFIG,
        }
        cdir = CONFIG["checkpoint_dir"]
        path = os.path.join(cdir, f"iter_{iteration:04d}.pt")
        # Atomic write: write to .tmp then os.replace (single syscall, crash-safe)
        torch.save(ckpt, path + ".tmp"); os.replace(path + ".tmp", path)
        latest = os.path.join(cdir, "latest.pt")
        shutil.copy2(path, latest + ".tmp"); os.replace(latest + ".tmp", latest)
        if is_best:
            best = os.path.join(cdir, "best.pt")
            shutil.copy2(path, best + ".tmp"); os.replace(best + ".tmp", best)

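    # The .tmp + os.replace dance keeps checkpoints crash-safe: os.replace is
    # atomic when source and destination are on the same filesystem, so
    # latest.pt is always either the old file or the new one, never a torn mix.
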
    # ── ELO Evaluation (batched, greedy) ───────────────────────────────────────
    def _elo_game_done(self, board: chess.Board, idx: int, agent_color,
                       scores: np.ndarray, active: np.ndarray) -> None:
        if board.is_game_over():
            res = board.result()
            if (res == "1-0" and agent_color == chess.WHITE) or \
               (res == "0-1" and agent_color == chess.BLACK):
                scores[idx] = 1.0
            elif res == "1/2-1/2":
                scores[idx] = 0.5
            else:
                scores[idx] = 0.0
            active[idx] = False

    def evaluate_elo(self, n_games: int = 32, max_ply: int = 200) -> tuple:
        """
        Play n_games vs random opponent (batched GPU for agent moves).
        Half games as White, half as Black.
        Returns (wins, draws, losses) from agent's perspective.
        """
        self.model.eval()
        boards = [chess.Board() for _ in range(n_games)]
        agent_colors = [chess.WHITE if i % 2 == 0 else chess.BLACK
                        for i in range(n_games)]
        scores = np.full(n_games, 0.5, dtype=np.float32)  # default: draw
        active = np.ones(n_games, dtype=bool)
        # uint64: a bitboard with bit 63 set (piece on h8) overflows int64
        bbs_sub = np.zeros((n_games, 12), dtype=np.uint64)
        meta_sub = np.zeros((n_games, 3), dtype=np.float32)

        for _ in range(max_ply):
            if not active.any(): break

            # Random moves (opponent turns) - CPU
            for i in [i for i in range(n_games)
                      if active[i] and boards[i].turn != agent_colors[i]]:
                legal = list(boards[i].legal_moves)
                if legal: boards[i].push(random.choice(legal))
                self._elo_game_done(boards[i], i, agent_colors[i], scores, active)

            # Agent moves (batched GPU)
            ag_idx = [i for i in range(n_games)
                      if active[i] and boards[i].turn == agent_colors[i]]
            if not ag_idx:
                continue

            n = len(ag_idx)
            sub = [boards[i] for i in ag_idx]
            act_sub = np.ones(n, dtype=bool)
            populate_states_fast(sub, act_sub, bbs_sub[:n], meta_sub[:n])

            # Reinterpret uint64 bit patterns as int64 for torch (bits unchanged)
            bbs_t = torch.tensor(bbs_sub[:n].view(np.int64), dtype=torch.int64, device=self.device)
            unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).float().view(n, 12, 8, 8)
            state = torch.zeros(n, 14, 8, 8, device=self.device, dtype=torch.float32)
            state[:, :12] = unpacked
            state[:, 12] = torch.tensor(meta_sub[:n, 0], device=self.device).view(n, 1, 1).expand(n, 8, 8)
            state[:, 13] = torch.tensor(meta_sub[:n, 1], device=self.device).view(n, 1, 1).expand(n, 8, 8)
            for lj in range(n):
                if meta_sub[lj, 2]:
                    state[lj, 13, 0, 1] = float(meta_sub[lj, 2])

            with torch.no_grad(), torch.amp.autocast('cuda'):
                logits, _ = self.model(state.to(memory_format=torch.channels_last))
            logits = logits.float()

            masks_np, legal_lists = get_legal_masks(sub, act_sub)
            masks_t = torch.tensor(masks_np, dtype=torch.bool, device=self.device)
            logits = torch.where(masks_t, logits,
                                 torch.tensor(-60000.0, device=self.device))
            best_acts = logits.argmax(dim=-1).cpu().numpy()  # greedy for evaluation

            for lj, gi in enumerate(ag_idx):
                if not active[gi]: continue
                move_uci = ACTION_MAPPER.idx_to_move[best_acts[lj]]
                move = chess.Move.from_uci(move_uci)
                legal = legal_lists[lj] or list(boards[gi].legal_moves)
                if not legal:
                    active[gi] = False; continue
                if move not in legal:
                    move = random.choice(legal)
                boards[gi].push(move)
                self._elo_game_done(boards[gi], gi, agent_colors[gi], scores, active)

        wins = int((scores == 1.0).sum())
        draws = int((scores == 0.5).sum())
        losses = int((scores == 0.0).sum())
        for s in scores:
            self.elo_tracker.update(float(s), RANDOM_BASELINE_ELO)
        return wins, draws, losses

    # ── Main Training Loop ─────────────────────────────────────────────────────
    def train(self, num_iterations: int) -> None:
        B = CONFIG["num_envs"]
        max_steps = CONFIG["max_steps"]
        G = CONFIG["grpo_group_size"]
        num_groups = B // G
        gamma = CONFIG["gamma"]
        t_start = time.time()
        max_rt = CONFIG["max_runtime_hours"] * 3600.0

        # ── Preallocate GPU buffers (int8/bool minimizes VRAM footprint) ──────
        states_buf = torch.zeros((max_steps, B, 14, 8, 8), dtype=torch.int8, device=self.device)
        actions_buf = torch.zeros((max_steps, B), dtype=torch.int16, device=self.device)
        logprobs_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        values_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        rewards_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        dones_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)
        active_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)

        # uint64: bitboards use all 64 bits; int64 assignment would overflow
        bbs_np = np.zeros((B, 12), dtype=np.uint64)
        meta_np = np.zeros((B, 3), dtype=np.float32)

        vram_gb = (torch.cuda.get_device_properties(0).total_memory / 1e9
                   if torch.cuda.is_available() else 0.0)
        print("\nAggressive GRPO Chess Agent")
        print(f" Envs:{B} | Groups:{num_groups}×G:{G} | Device:{self.device.upper()} | "
              f"VRAM:{vram_gb:.1f}GB")
        print(" Reward: capture(0-0.3)+check(0.3)+checkmate_speed(1.0-1.5)"
              "+draw_penalty(-0.5)+time(-0.003/step)")
        print(f" gamma:{gamma} | entropy:{CONFIG['entropy_coef']} | "
              f"lr:{CONFIG['learning_rate']}")

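        # Buffer budget: states_buf dominates at 100 * 256 * 14 * 64 bytes
        # ≈ 23 MB in int8 (float32 would be ≈ 92 MB); each [max_steps, B]
        # float32 buffer adds only ~100 KB, leaving VRAM for activations.
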
        for iteration in range(self.start_iter, num_iterations):

            # ── Runtime guard ──────────────────────────────────────────────
            elapsed = time.time() - t_start
            if elapsed > max_rt:
                print(f"\n⏱ {elapsed/3600:.2f}h reached. Saving & downloading…")
                self.save_checkpoint(iteration)
                self.plot_metrics()
                auto_download(CONFIG["checkpoint_dir"])
                break

            iter_start = time.time()

            # Zero buffers in-place (no reallocation)
            states_buf.zero_(); actions_buf.zero_(); logprobs_buf.zero_()
            values_buf.zero_(); rewards_buf.zero_()
            dones_buf.fill_(False); active_buf.fill_(False)

            # ── GRPO: each group of G envs shares an opening position ──────
            fens = [get_opening_position(CONFIG["opening_max_moves"]).fen()
                    for _ in range(num_groups)]
            envs: list[chess.Board] = []
            for gi in range(num_groups):
                for _ in range(G):
                    envs.append(chess.Board(fens[gi]))

            active = np.ones(B, dtype=bool)
            game_lengths = np.zeros(B, dtype=np.int32)

            # Per-iteration attack metrics
            white_wins = black_wins = draws_count = 0
            total_checks = total_captures = 0

            # ── PHASE 1: ROLLOUT ───────────────────────────────────────────
            for t in range(max_steps):
                if not active.any(): break

                populate_states_fast(envs, active, bbs_np, meta_np)

                # Bit-unpack bitboards → int8 state tensor (no float copy).
                # uint64 reinterpreted as int64: the arithmetic shift still
                # drops bit s to position 0, so (x >> s) & 1 is exact.
                bbs_t = torch.as_tensor(bbs_np.view(np.int64), dtype=torch.int64, device=self.device)
                # Scale 0/1 piece bits to 0/127 so the /127 normalization
                # below recovers 1.0, matching evaluate_elo's encoding
                unpacked = (((bbs_t.unsqueeze(-1) >> self.shifts) & 1) * 127).to(torch.int8)
                meta_t = torch.as_tensor(meta_np, dtype=torch.float32, device=self.device)

                # Pack into int8 buffer (scale float meta to [-127,127])
                states_buf[t, :, :12, :, :] = unpacked.view(B, 12, 8, 8)
                states_buf[t, :, 12, :, :] = (meta_t[:, 0] * 127).clamp(-127, 127) \
                    .to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
                states_buf[t, :, 13, :, :] = (meta_t[:, 1] * 127).clamp(0, 127) \
                    .to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
                states_buf[t, :, 13, 0, 1] = (meta_t[:, 2] * 127).clamp(0, 127).to(torch.int8)
                active_buf[t] = torch.as_tensor(active, dtype=torch.bool, device=self.device)

                # Normalize int8 → float32 for forward pass
                model_input = states_buf[t].to(
                    dtype=torch.float32, memory_format=torch.channels_last) / 127.0

                self.model.eval()
                with torch.no_grad(), torch.amp.autocast('cuda'):
                    logits, values = self.model(model_input)

                masks_np, legal_moves_list = get_legal_masks(envs, active)
                masks_t = torch.as_tensor(masks_np, dtype=torch.bool, device=self.device)
                logits = logits.float()
                logits = torch.where(masks_t, logits,
                                     torch.tensor(-60000.0, device=self.device))
                no_legal = ~masks_t.any(dim=-1, keepdim=True)
                logits.masked_fill_(no_legal, 0.0)

                probs = F.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)
                actions = dist.sample()

                actions_buf[t] = actions.to(torch.int16)
                logprobs_buf[t] = dist.log_prob(actions)
                values_buf[t] = values.squeeze(-1)

                actions_cpu = actions.cpu().numpy()

                for b in range(B):
                    if not active[b]: continue

                    move_uci = ACTION_MAPPER.idx_to_move[actions_cpu[b]]
                    move = chess.Move.from_uci(move_uci)
                    if move not in legal_moves_list[b]:
                        move = random.choice(legal_moves_list[b])

                    board = envs[b]
                    mover_is_white = (board.turn == chess.WHITE)
                    sign = 1.0 if mover_is_white else -1.0

                    # ── Reward: pre-push components ─────────────────────
                    r = -0.003 * sign  # time penalty (per-mover, white-perspective)

                    if board.is_capture(move):
                        if board.is_en_passant(move):
                            cap_val = 1.0
                        else:
                            cp = board.piece_at(move.to_square)
                            cap_val = PIECE_VAL.get(cp.piece_type, 0.0) if cp else 0.0
                        r += sign * (cap_val / 9.0) * 0.3  # [0, 0.3]
                        total_captures += 1

                    if move.promotion in (chess.QUEEN, chess.ROOK):
                        r += sign * 0.15  # aggressive promotion

                    board.push(move)
                    game_lengths[b] += 1

                    # ── Reward: post-push components ────────────────────
                    if board.is_check():
                        r += sign * 0.3  # gave check
                        total_checks += 1

                    if board.is_game_over():
                        if board.is_checkmate():
                            # Mover delivered checkmate
                            speed_bonus = 0.5 * math.exp(-game_lengths[b] / 20.0)
                            r += sign * (1.0 + speed_bonus)  # ~1.0-1.5
                            if mover_is_white: white_wins += 1
                            else: black_wins += 1
                        else:
                            # Draw (stalemate / 50-move / repetition / insufficient material)
                            r -= 0.5  # flat penalty from white's perspective → attack to WIN
                            draws_count += 1
                        dones_buf[t, b] = True
                        active[b] = False

                    rewards_buf[t, b] = r
                # end per-env loop
            # end rollout

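            # Worked example of the shaping above: as White, capturing a queen
            # with check earns r = -0.003 + (9/9)*0.3 + 0.3 ≈ +0.597; the same
            # move by Black scores ≈ -0.597 (rewards are expressed from
            # White's perspective, sign-flipped per mover).
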
            # ── PHASE 2: VECTORIZED RETURNS ────────────────────────────────
            returns = torch.zeros(B, dtype=torch.float32, device=self.device)
            returns_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
            not_done_f = (~dones_buf).float()
            for step in reversed(range(max_steps)):
                returns = rewards_buf[step] + gamma * returns * not_done_f[step]
                returns_buf[step] = returns

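            # The backward sweep above computes R_t = r_t + γ·R_{t+1}·(1 - done_t)
            # for all B envs at once; with γ = 0.98 the effective credit horizon
            # is roughly 1 / (1 - γ) = 50 plies.
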
            # ── PHASE 3: GRPO GROUP-WISE ADVANTAGE NORMALIZATION ───────────
            # advantages shape [max_steps, B]
            adv_raw = returns_buf - values_buf
            active_f = active_buf.float()

            # Reshape to [max_steps, num_groups, G] and normalize within each group
            adv_3d = adv_raw.view(max_steps, num_groups, G)
            act_3d = active_f.view(max_steps, num_groups, G)

            g_count = act_3d.sum(dim=[0, 2]).clamp(min=1.0)       # [num_groups]
            g_mean = (adv_3d * act_3d).sum(dim=[0, 2]) / g_count  # [num_groups]
            g_sq_diff = ((adv_3d - g_mean.view(1, num_groups, 1)) ** 2
                         * act_3d).sum(dim=[0, 2])
            g_std = (g_sq_diff / g_count).sqrt().clamp(min=1e-8)  # [num_groups]
            adv_3d = (adv_3d - g_mean.view(1, num_groups, 1)) / \
                     g_std.view(1, num_groups, 1)
            adv_norm = adv_3d.view(max_steps, B)

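            # GRPO's core idea: standardize each advantage against its own
            # group, Â = (A - μ_g) / σ_g, with μ_g and σ_g masked over the
            # active steps of the G siblings that shared this group's opening.
            # A move only looks good if it beat the other continuations of
            # the same starting position.
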
            # Flatten, filter to active steps only
            valid_mask = active_buf.view(-1)
            flat_states = (states_buf.view(-1, 14, 8, 8)[valid_mask]
                           .to(torch.float32, memory_format=torch.channels_last)
                           .div_(127.0))
            flat_actions = actions_buf.view(-1)[valid_mask].to(torch.int64)
            flat_old_lp = logprobs_buf.view(-1)[valid_mask]
            flat_returns = returns_buf.view(-1)[valid_mask]
            flat_advantages = adv_norm.view(-1)[valid_mask]

            dataset_size = flat_states.size(0)
            if dataset_size < 100:
                continue  # skip degenerate rollout (all games ended instantly)

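            # PHASE 4 below is PPO's clipped surrogate on these samples:
            #   L = E[min(ρ·Â, clip(ρ, 1-ε, 1+ε)·Â)],  ρ = π_new / π_old,
            # plus a value MSE term and an entropy bonus. One visible caveat:
            # new log-probs come from unmasked logits while the stored ones
            # were legality-masked, so the ratio is approximate; keeping
            # per-step masks would remove the mismatch at some memory cost.
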
            # ── PHASE 4: PPO OPTIMIZATION ──────────────────────────────────
            self.model.train()
            total_p_loss = total_v_loss = 0.0
            num_updates = 0
            mb_size = CONFIG["mini_batch_size"]

            for _ in range(CONFIG["ppo_epochs"]):
                perm = torch.randperm(dataset_size, device=self.device)
                for start in range(0, dataset_size, mb_size):
                    mb = perm[start: start + mb_size]
                    with torch.amp.autocast('cuda'):
                        new_logits, new_vals = self.model(flat_states[mb])
                        new_dist = torch.distributions.Categorical(logits=new_logits)
                        new_lp = new_dist.log_prob(flat_actions[mb])
                        ratio = torch.exp(new_lp - flat_old_lp[mb])
                        adv = flat_advantages[mb]
                        surr1 = ratio * adv
                        surr2 = torch.clamp(
                            ratio,
                            1.0 - CONFIG["clip_epsilon"],
                            1.0 + CONFIG["clip_epsilon"],
                        ) * adv
                        p_loss = -torch.min(surr1, surr2).mean()
                        v_loss = F.mse_loss(new_vals.squeeze(-1), flat_returns[mb])
                        entropy = new_dist.entropy().mean()
                        loss = (p_loss
                                + CONFIG["value_coef"] * v_loss
                                - CONFIG["entropy_coef"] * entropy)

                    self.optimizer.zero_grad(set_to_none=True)
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    total_p_loss += p_loss.item()
                    total_v_loss += v_loss.item()
                    num_updates += 1

            # ── PHASE 5: METRICS & LOGGING ────────────────────────────────
            done_count = white_wins + black_wins + draws_count
            win_rate = white_wins / max(done_count, 1)
            draw_rate = draws_count / max(done_count, 1)
            active_steps = int(active_buf.sum().item())
            check_rate = total_checks / max(active_steps, 1)
            capture_rate = total_captures / max(active_steps, 1)
            avg_game_len = float(game_lengths.mean())
            fps = dataset_size / max(time.time() - iter_start, 1e-3)

            if (iteration + 1) % CONFIG["log_interval"] == 0:
                vram_alloc = (torch.cuda.memory_allocated() / 1e9
                              if torch.cuda.is_available() else 0.0)
                vram_res = (torch.cuda.memory_reserved() / 1e9
                            if torch.cuda.is_available() else 0.0)
                print(
                    f"[{iteration+1:05d}] "
                    f"P:{total_p_loss/max(1,num_updates):.4f} "
                    f"V:{total_v_loss/max(1,num_updates):.4f} | "
                    f"W:{win_rate:.3f} D:{draw_rate:.3f} "
                    f"Chk:{check_rate:.4f} Cap:{capture_rate:.4f} "
                    f"Len:{avg_game_len:.1f} | "
                    f"ELO:{self.elo_tracker.elo:.0f} | "
                    f"FPS:{fps:.0f} | "
                    f"VRAM:{vram_alloc:.2f}/{vram_res:.2f}GB"
                )
                with open(self.log_file, "a", newline="") as f:
                    csv.writer(f).writerow([
                        iteration + 1,
                        total_p_loss / max(1, num_updates),
                        total_v_loss / max(1, num_updates),
                        flat_returns.mean().item(),
                        fps, win_rate, draw_rate,
                        check_rate, capture_rate, avg_game_len,
                    ])

            # Save best checkpoint when win_rate improves
            if win_rate > self.best_win_rate:
                self.best_win_rate = win_rate
                self.save_checkpoint(iteration + 1, is_best=True)

            if (iteration + 1) % CONFIG["save_interval"] == 0:
                self.save_checkpoint(iteration + 1)
                self.plot_metrics()

            # ELO evaluation
            if (iteration + 1) % CONFIG["elo_eval_interval"] == 0:
                elo_before = self.elo_tracker.elo
                ew, ed, el = self.evaluate_elo(CONFIG["elo_eval_games"])
                print(
                    f" [ELO eval] {elo_before:.0f} → {self.elo_tracker.elo:.0f} | "
                    f"W:{ew} D:{ed} L:{el} vs random({RANDOM_BASELINE_ELO})"
                )
                with open(self.elo_log_file, "a", newline="") as f:
                    csv.writer(f).writerow(
                        [iteration + 1, self.elo_tracker.elo, ew, ed, el])
                self.plot_metrics()

            # Aggressive cache reclaim (free fragmented blocks, not pinned allocs)
            torch.cuda.empty_cache()

    # ── Plotting ───────────────────────────────────────────────────────────────
    def plot_metrics(self) -> None:
        if not os.path.exists(self.log_file): return
        df = pd.read_csv(self.log_file)
        if len(df) < 2: return

        elo_df = None
        if os.path.exists(self.elo_log_file):
            elo_df = pd.read_csv(self.elo_log_file)

        fig, axs = plt.subplots(3, 2, figsize=(14, 12))
        fig.suptitle("Aggressive GRPO Chess Agent - Training Dashboard", fontsize=14)

        # Row 0: Losses
        axs[0, 0].plot(df['iteration'], df['p_loss'], color='steelblue', linewidth=1.2)
        axs[0, 0].set_title('Policy Loss'); axs[0, 0].set_xlabel('Iteration')

        axs[0, 1].plot(df['iteration'], df['v_loss'], color='tomato', linewidth=1.2)
        axs[0, 1].set_title('Value Loss'); axs[0, 1].set_xlabel('Iteration')

        # Row 1: Outcomes
        axs[1, 0].plot(df['iteration'], df['win_rate'], label='Win', color='green')
        axs[1, 0].plot(df['iteration'], df['draw_rate'], label='Draw', color='orange')
        axs[1, 0].set_title('Outcomes (White perspective)')
        axs[1, 0].legend(); axs[1, 0].set_xlabel('Iteration')

        # Row 1: Attack metrics
        axs[1, 1].plot(df['iteration'], df['check_rate'], label='Check/step', color='purple')
        axs[1, 1].plot(df['iteration'], df['capture_rate'], label='Capture/step', color='darkorange')
        axs[1, 1].set_title('Attack Metrics (higher = more aggressive)')
        axs[1, 1].legend(); axs[1, 1].set_xlabel('Iteration')

        # Row 2: ELO Rating
        if elo_df is not None and len(elo_df) > 0:
            axs[2, 0].plot(elo_df['iteration'], elo_df['elo'],
                           color='gold', linewidth=2.0, label='Agent ELO')
            axs[2, 0].axhline(RANDOM_BASELINE_ELO, linestyle='--',
                              color='gray', alpha=0.8, label=f'Random ({RANDOM_BASELINE_ELO})')
            axs[2, 0].axhline(1200, linestyle=':', color='lightblue',
                              alpha=0.6, label='Start (1200)')
            axs[2, 0].fill_between(elo_df['iteration'], RANDOM_BASELINE_ELO,
                                   elo_df['elo'], alpha=0.15, color='gold')
            axs[2, 0].set_title('ELO Rating vs Random Baseline')
            axs[2, 0].legend(); axs[2, 0].set_xlabel('Iteration')
        else:
            axs[2, 0].text(0.5, 0.5, f'ELO eval every {CONFIG["elo_eval_interval"]} iters',
                           ha='center', va='center', transform=axs[2, 0].transAxes,
                           color='gray', fontsize=11)
            axs[2, 0].set_title('ELO Rating (pending)')

        # Row 2: Average game length
        axs[2, 1].plot(df['iteration'], df['avg_game_len'], color='teal', linewidth=1.2)
        axs[2, 1].set_title('Avg Game Length (lower = faster checkmates)')
        axs[2, 1].set_xlabel('Iteration')

        for ax in axs.flat:
            ax.grid(True, alpha=0.25)

        plt.tight_layout()
        out = os.path.join(CONFIG["checkpoint_dir"], "training_performance.png")
        plt.savefig(out, dpi=100, bbox_inches='tight')
        plt.close(fig)
        print(f" [Plot] saved → {out}")

# ── Entry Point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aggressive GRPO Chess Agent (T4/Colab)")
    parser.add_argument("--iterations", type=int, default=10000,
                        help="Total training iterations")
    parser.add_argument("--test-batch", action="store_true",
                        help="Run 2 iterations for smoke-test")
    args, _ = parser.parse_known_args()

    torch.manual_seed(CONFIG["seed"])
    np.random.seed(CONFIG["seed"])
    random.seed(CONFIG["seed"])

    # Print VRAM summary at startup
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        print(f"GPU: {props.name} | VRAM: {props.total_memory/1e9:.1f}GB | "
              f"SM: {props.multi_processor_count} | "
              f"Compute: {props.major}.{props.minor}")

    trainer = GRPOTrainer()
    trainer.train(2 if args.test_batch else args.iterations)
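
# Usage (the trainer resumes from ./checkpoints/latest.pt automatically if
# one exists):
#   python model_aggressive.py --test-batch      # 2-iteration smoke test
#   python model_aggressive.py --iterations 10000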