"""
Aggressive GRPO Chess Agent – T4/Colab Optimized
"""


import os, sys, csv, time, math, shutil, argparse, random
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend: render plots to files, no display needed
import matplotlib.pyplot as plt


try:
    import chess
except ImportError:
    os.system("pip install -q chess")
    import chess


import torch
import torch.nn as nn
import torch.nn.functional as F

# Speed knobs: cuDNN autotuning suits the fixed-shape conv workload, and TF32
# is enabled where the hardware supports it (a no-op on the T4's Turing SMs).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision('high')


PIECE_VAL = {
    chess.PAWN: 1.0, chess.KNIGHT: 3.0, chess.BISHOP: 3.2,
    chess.ROOK: 5.0, chess.QUEEN: 9.0, chess.KING: 0.0,
}
RANDOM_BASELINE_ELO = 800  # assumed rating of the uniform-random opponent


CONFIG = {
    "num_envs": 256,
    "grpo_group_size": 8,        # envs per GRPO group; a group shares one opening
    "ppo_epochs": 3,
    "mini_batch_size": 4096,
    "learning_rate": 2e-4,
    "weight_decay": 1e-4,
    "gamma": 0.98,
    "clip_epsilon": 0.15,
    "entropy_coef": 0.02,
    "value_coef": 0.5,
    "max_steps": 100,            # plies collected per env per iteration
    "opening_max_moves": 10,
    "checkpoint_dir": "./checkpoints",
    "save_interval": 50,
    "log_interval": 1,
    "elo_eval_interval": 100,
    "elo_eval_games": 32,
    "max_runtime_hours": 4.5,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "seed": 42,
}
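
# Rollout sizing under these defaults: at most num_envs * max_steps =
# 256 * 100 = 25,600 transitions per iteration, organized as 256 // 8 = 32
# GRPO groups of 8 environments that start from a shared opening position.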


class ActionMapper:
    """Dense UCI move <-> index mapping over all from/to square pairs."""
    __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']

    def __init__(self):
        self.move_to_idx: dict[str, int] = {}
        self.idx_to_move: list[str] = []
        idx = 0
        for f in range(64):
            for t in range(64):
                if f == t: continue
                uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t]
                self.move_to_idx[uci] = idx
                self.idx_to_move.append(uci)
                idx += 1
                # Promotion variants exist only for pawn moves onto the back
                # rank: 7th -> 8th for White, 2nd -> 1st for Black, shifting
                # at most one file.
                if ((chess.square_rank(f) == 6 and chess.square_rank(t) == 7) or
                    (chess.square_rank(f) == 1 and chess.square_rank(t) == 0)) and \
                        abs(chess.square_file(f) - chess.square_file(t)) <= 1:
                    for promo in "nbrq":
                        puci = uci + promo
                        self.move_to_idx[puci] = idx
                        self.idx_to_move.append(puci)
                        idx += 1
        self.num_actions = idx


ACTION_MAPPER = ActionMapper()
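
# Optional import-time sanity check (safe to delete): the two lookup tables
# must be exact inverses of one another.
assert all(ACTION_MAPPER.move_to_idx[u] == i
           for i, u in enumerate(ACTION_MAPPER.idx_to_move))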


def populate_states_fast(envs: list, active_mask: np.ndarray,
                         bbs_np: np.ndarray, meta_np: np.ndarray) -> None:
    """Fill bbs_np [B,12] int64 and meta_np [B,3] float32 for active envs."""
    for b in range(len(envs)):
        if not active_mask[b]: continue
        env = envs[b]
        w = env.occupied_co[chess.WHITE]
        bc = env.occupied_co[chess.BLACK]
        bbs_np[b, 0] = env.pawns & w;    bbs_np[b, 1] = env.knights & w
        bbs_np[b, 2] = env.bishops & w;  bbs_np[b, 3] = env.rooks & w
        bbs_np[b, 4] = env.queens & w;   bbs_np[b, 5] = env.kings & w
        bbs_np[b, 6] = env.pawns & bc;   bbs_np[b, 7] = env.knights & bc
        bbs_np[b, 8] = env.bishops & bc; bbs_np[b, 9] = env.rooks & bc
        bbs_np[b, 10] = env.queens & bc; bbs_np[b, 11] = env.kings & bc
        meta_np[b, 0] = 1.0 if env.turn else -1.0
        # python-chess stores castling_rights as a bitboard of rook squares,
        # not a 4-bit mask, so normalize by popcount (at most 4 bits set).
        meta_np[b, 1] = bin(env.castling_rights).count("1") / 4.0
        meta_np[b, 2] = 1.0 if env.ep_square is not None else 0.0


def get_legal_masks(envs: list, active_mask: np.ndarray):
    """Return a [B, num_actions] legality mask plus per-env legal-move lists."""
    masks = np.zeros((len(envs), ACTION_MAPPER.num_actions), dtype=np.bool_)
    moves_list = [None] * len(envs)
    for b in range(len(envs)):
        if not active_mask[b]: continue
        legal = list(envs[b].legal_moves)
        moves_list[b] = legal
        for m in legal:
            masks[b, ACTION_MAPPER.move_to_idx[m.uci()]] = True
    return masks, moves_list


class ChessNet(nn.Module):
    """AlphaZero-style residual tower with separate policy and value heads."""
    def __init__(self, res_blocks: int = 8, channels: int = 128):
        super().__init__()
        self.conv_in = nn.Conv2d(14, channels, 3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
            ) for _ in range(res_blocks)
        ])
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Flatten(),
            nn.Linear(32 * 64, ACTION_MAPPER.num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Flatten(),
            nn.Linear(32 * 64, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
        for blk in self.res_blocks:
            x = F.relu(x + blk(x), inplace=True)
        return self.policy_head(x), self.value_head(x)
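
# Input layout fed by the trainer: 14 planes of 8x8. Planes 0-11 are the piece
# bitboards (White P,N,B,R,Q,K then Black), plane 12 broadcasts the side to
# move (+1/-1), plane 13 broadcasts the castling-rights scalar, and the
# en-passant flag is written into the single cell [0, 1] of plane 13.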


class ELOTracker:
    """Incremental Elo estimate using the standard logistic expected score."""
    def __init__(self, initial_elo: float = 1200.0, K: float = 32.0):
        self.elo = initial_elo
        self.K = K

    def expected(self, opp_elo: float) -> float:
        return 1.0 / (1.0 + 10.0 ** ((opp_elo - self.elo) / 400.0))

    def update(self, score: float, opp_elo: float) -> None:
        self.elo += self.K * (score - self.expected(opp_elo))


def get_opening_position(max_moves: int = 10) -> chess.Board:
    """Play 0..max_moves random half-moves from the start, for GRPO diversity."""
    board = chess.Board()
    for _ in range(random.randint(0, max_moves)):
        if board.is_game_over(): break
        board.push(random.choice(list(board.legal_moves)))
    return chess.Board(board.fen())  # FEN round-trip drops the move stack
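
# Sharing one such opening FEN across a whole GRPO group (see train()) is what
# makes the group-relative baseline meaningful: group members then differ only
# in the actions they sample, not in their starting position.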


def auto_download(checkpoint_dir: str) -> None:
    """Sync to Google Drive if mounted, else trigger browser downloads."""
    try:
        from google.colab import files as _cf
        drive_dst = '/content/drive/MyDrive/chess_agent'
        if os.path.exists('/content/drive/MyDrive'):
            os.makedirs(drive_dst, exist_ok=True)
            shutil.copytree(checkpoint_dir, drive_dst, dirs_exist_ok=True)
            print(f"[AutoSave] Synced → {drive_dst}")
        else:
            for fname in ['best.pt', 'latest.pt', 'training_log.csv',
                          'elo_log.csv', 'training_performance.png']:
                fpath = os.path.join(checkpoint_dir, fname)
                if os.path.exists(fpath):
                    _cf.download(fpath)
                    print(f"[AutoSave] Downloaded {fname}")
    except Exception as e:
        print(f"[AutoSave] {e}")


class GRPOTrainer:

    def __init__(self):
        self.device = CONFIG["device"]

        _model = ChessNet(res_blocks=8, channels=128)
        _model = _model.to(self.device).to(memory_format=torch.channels_last)
        try:
            print("Compiling model (reduce-overhead)…")
            self.model = torch.compile(_model, mode="reduce-overhead")
        except Exception:
            self.model = _model

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=CONFIG["learning_rate"],
            weight_decay=CONFIG["weight_decay"],
            fused=torch.cuda.is_available(),  # use the fused kernel only with CUDA
        )
        self.scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
        self.start_iter = 0
        self.best_win_rate = 0.0
        self.elo_tracker = ELOTracker()

        # Bit positions 0..63, used to unpack 64-bit bitboards into 8x8 planes.
        self.shifts = torch.arange(64, dtype=torch.int64,
                                   device=self.device).view(1, 1, 64)

        os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)
        self.log_file = os.path.join(CONFIG["checkpoint_dir"], "training_log.csv")
        self.elo_log_file = os.path.join(CONFIG["checkpoint_dir"], "elo_log.csv")

        if not os.path.exists(self.log_file):
            with open(self.log_file, "w", newline="") as f:
                csv.writer(f).writerow([
                    "iteration", "p_loss", "v_loss", "ret_mean", "fps",
                    "win_rate", "draw_rate", "check_rate", "capture_rate", "avg_game_len",
                ])
        if not os.path.exists(self.elo_log_file):
            with open(self.elo_log_file, "w", newline="") as f:
                csv.writer(f).writerow(
                    ["iteration", "elo", "eval_wins", "eval_draws", "eval_losses"])

        self._init_checkpointing()

    def _init_checkpointing(self) -> None:
        latest = os.path.join(CONFIG["checkpoint_dir"], "latest.pt")
        if not os.path.exists(latest):
            return
        try:
            ckpt = torch.load(latest, map_location=self.device, weights_only=False)
            sd = ckpt['model_state_dict']
            # torch.compile wraps modules and prefixes keys with '_orig_mod.';
            # try the saved keys as-is, stripped, and re-prefixed.
            loaded = False
            for attempt in [
                sd,
                {k.replace('_orig_mod.', ''): v for k, v in sd.items()},
                {'_orig_mod.' + k: v for k, v in sd.items()},
            ]:
                try:
                    self.model.load_state_dict(attempt); loaded = True; break
                except RuntimeError:
                    continue
            if not loaded:
                raise RuntimeError("All state-dict key variants failed.")
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.scaler.load_state_dict(ckpt['scaler_state_dict'])
            self.start_iter = ckpt.get('iteration', 0) + 1
            self.elo_tracker.elo = ckpt.get('elo', 1200.0)
            self.best_win_rate = ckpt.get('best_win_rate', 0.0)
            print(f"Resumed from iter {self.start_iter} | "
                  f"ELO {self.elo_tracker.elo:.0f} | best_win {self.best_win_rate:.3f}")
        except Exception as e:
            print(f"Checkpoint load failed ({e}). Starting fresh.")

    def save_checkpoint(self, iteration: int, is_best: bool = False) -> None:
        ckpt = {
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'elo': self.elo_tracker.elo,
            'best_win_rate': self.best_win_rate,
            'config': CONFIG,
        }
        cdir = CONFIG["checkpoint_dir"]
        path = os.path.join(cdir, f"iter_{iteration:04d}.pt")
        # Write-then-rename keeps checkpoints atomic if the runtime dies mid-save.
        torch.save(ckpt, path + ".tmp"); os.replace(path + ".tmp", path)
        latest = os.path.join(cdir, "latest.pt")
        shutil.copy2(path, latest + ".tmp"); os.replace(latest + ".tmp", latest)
        if is_best:
            best = os.path.join(cdir, "best.pt")
            shutil.copy2(path, best + ".tmp"); os.replace(best + ".tmp", best)

    def _elo_game_done(self, board: chess.Board, idx: int, agent_color,
                       scores: np.ndarray, active: np.ndarray) -> None:
        if board.is_game_over():
            res = board.result()
            if (res == "1-0" and agent_color == chess.WHITE) or \
               (res == "0-1" and agent_color == chess.BLACK):
                scores[idx] = 1.0
            elif res == "1/2-1/2":
                scores[idx] = 0.5
            else:
                scores[idx] = 0.0
            active[idx] = False

    def evaluate_elo(self, n_games: int = 32, max_ply: int = 200) -> tuple:
        """
        Play n_games vs the random opponent (agent moves batched on GPU).
        Half the games as White, half as Black; games still unfinished after
        max_ply are scored as draws.
        Returns (wins, draws, losses) from the agent's perspective.
        """
        self.model.eval()
        boards = [chess.Board() for _ in range(n_games)]
        agent_colors = [chess.WHITE if i % 2 == 0 else chess.BLACK
                        for i in range(n_games)]
        scores = np.full(n_games, 0.5, dtype=np.float32)
        active = np.ones(n_games, dtype=bool)
        bbs_sub = np.zeros((n_games, 12), dtype=np.int64)
        meta_sub = np.zeros((n_games, 3), dtype=np.float32)

        for _ in range(max_ply):
            if not active.any(): break

            # Opponent half-moves: uniformly random legal moves.
            for i in [i for i in range(n_games)
                      if active[i] and boards[i].turn != agent_colors[i]]:
                legal = list(boards[i].legal_moves)
                if legal: boards[i].push(random.choice(legal))
                self._elo_game_done(boards[i], i, agent_colors[i], scores, active)

            # Agent half-moves: batch every board where it is the agent's turn.
            ag_idx = [i for i in range(n_games)
                      if active[i] and boards[i].turn == agent_colors[i]]
            if not ag_idx:
                continue

            n = len(ag_idx)
            sub = [boards[i] for i in ag_idx]
            act_sub = np.ones(n, dtype=bool)
            populate_states_fast(sub, act_sub, bbs_sub[:n], meta_sub[:n])

            bbs_t = torch.tensor(bbs_sub[:n], dtype=torch.int64, device=self.device)
            unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).float().view(n, 12, 8, 8)
            state = torch.zeros(n, 14, 8, 8, device=self.device, dtype=torch.float32)
            state[:, :12] = unpacked
            state[:, 12] = torch.tensor(meta_sub[:n, 0], device=self.device).view(n, 1, 1).expand(n, 8, 8)
            state[:, 13] = torch.tensor(meta_sub[:n, 1], device=self.device).view(n, 1, 1).expand(n, 8, 8)
            for lj in range(n):
                if meta_sub[lj, 2]:
                    state[lj, 13, 0, 1] = float(meta_sub[lj, 2])

            with torch.no_grad(), torch.amp.autocast('cuda'):
                logits, _ = self.model(state.to(memory_format=torch.channels_last))
            logits = logits.float()

            # Greedy play: push illegal actions to a large negative, then argmax.
            masks_np, legal_lists = get_legal_masks(sub, act_sub)
            masks_t = torch.tensor(masks_np, dtype=torch.bool, device=self.device)
            logits = torch.where(masks_t, logits,
                                 torch.tensor(-60000.0, device=self.device))
            best_acts = logits.argmax(dim=-1).cpu().numpy()

            for lj, gi in enumerate(ag_idx):
                if not active[gi]: continue
                move_uci = ACTION_MAPPER.idx_to_move[best_acts[lj]]
                move = chess.Move.from_uci(move_uci)
                legal = legal_lists[lj] or list(boards[gi].legal_moves)
                if not legal:
                    active[gi] = False; continue
                if move not in legal:
                    move = random.choice(legal)
                boards[gi].push(move)
                self._elo_game_done(boards[gi], gi, agent_colors[gi], scores, active)

        wins = int((scores == 1.0).sum())
        draws = int((scores == 0.5).sum())
        losses = int((scores == 0.0).sum())
        for s in scores:
            self.elo_tracker.update(float(s), RANDOM_BASELINE_ELO)
        return wins, draws, losses
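
    # Reading the Elo sweep: each evaluation applies n_games sequential K=32
    # updates against the fixed 800 baseline, so early evaluations can move the
    # rating by tens of points before it settles as the win rate saturates.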

    def train(self, num_iterations: int) -> None:
        B = CONFIG["num_envs"]
        max_steps = CONFIG["max_steps"]
        G = CONFIG["grpo_group_size"]
        num_groups = B // G
        gamma = CONFIG["gamma"]
        t_start = time.time()
        max_rt = CONFIG["max_runtime_hours"] * 3600.0

        # Preallocated on-device rollout buffers (int8 states keep VRAM small).
        states_buf = torch.zeros((max_steps, B, 14, 8, 8), dtype=torch.int8, device=self.device)
        actions_buf = torch.zeros((max_steps, B), dtype=torch.int16, device=self.device)
        logprobs_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        values_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        rewards_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        dones_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)
        active_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)
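        # Buffer math: states_buf is max_steps * B * 14 * 8 * 8 int8 values,
        # i.e. 100 * 256 * 896 bytes ≈ 23 MB; rollout storage stays cheap even
        # on a 16 GB T4.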

        bbs_np = np.zeros((B, 12), dtype=np.int64)
        meta_np = np.zeros((B, 3), dtype=np.float32)

        vram_gb = (torch.cuda.get_device_properties(0).total_memory / 1e9
                   if torch.cuda.is_available() else 0.0)
        print("\nAggressive GRPO Chess Agent")
        print(f" Envs:{B} | Groups:{num_groups}×G:{G} | Device:{self.device.upper()} | "
              f"VRAM:{vram_gb:.1f}GB")
        print(f" Reward: capture(0-0.3)+check(0.3)+checkmate_speed(1.0-1.5)"
              f"+draw_penalty(-0.5)+time(-0.003/step)")
        print(f" gamma:{gamma} | entropy:{CONFIG['entropy_coef']} | "
              f"lr:{CONFIG['learning_rate']}")

        for iteration in range(self.start_iter, num_iterations):

            # Hard wall-clock guard for Colab session limits.
            elapsed = time.time() - t_start
            if elapsed > max_rt:
                print(f"\n{elapsed/3600:.2f}h reached. Saving & downloading…")
                self.save_checkpoint(iteration)
                self.plot_metrics()
                auto_download(CONFIG["checkpoint_dir"])
                break

            iter_start = time.time()

            # Reset rollout buffers in place.
            states_buf.zero_(); actions_buf.zero_(); logprobs_buf.zero_()
            values_buf.zero_(); rewards_buf.zero_()
            dones_buf.fill_(False); active_buf.fill_(False)

            # One random opening per group; all G members of a group share it.
            fens = [get_opening_position(CONFIG["opening_max_moves"]).fen()
                    for _ in range(num_groups)]
            envs: list[chess.Board] = []
            for gi in range(num_groups):
                for _ in range(G):
                    envs.append(chess.Board(fens[gi]))

            active = np.ones(B, dtype=bool)
            game_lengths = np.zeros(B, dtype=np.int32)

            # Per-iteration outcome counters.
            white_wins = black_wins = draws_count = 0
            total_checks = total_captures = 0

            # Rollout: step all active boards in lock-step.
            for t in range(max_steps):
                if not active.any(): break

                populate_states_fast(envs, active, bbs_np, meta_np)

                # Unpack 64-bit bitboards into 12 binary 8x8 planes on-device.
                bbs_t = torch.as_tensor(bbs_np, dtype=torch.int64, device=self.device)
                unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).to(torch.int8)
                meta_t = torch.as_tensor(meta_np, dtype=torch.float32, device=self.device)

                # Store states as int8 scaled by 127; divided back out at use.
                states_buf[t, :, :12, :, :] = unpacked.view(B, 12, 8, 8)
                states_buf[t, :, 12, :, :] = (meta_t[:, 0] * 127).clamp(-127, 127) \
                    .to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
                states_buf[t, :, 13, :, :] = (meta_t[:, 1] * 127).clamp(0, 127) \
                    .to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
                states_buf[t, :, 13, 0, 1] = (meta_t[:, 2] * 127).clamp(0, 127).to(torch.int8)
                active_buf[t] = torch.as_tensor(active, dtype=torch.bool, device=self.device)

                model_input = states_buf[t].to(
                    dtype=torch.float32, memory_format=torch.channels_last) / 127.0

                self.model.eval()
                with torch.no_grad(), torch.amp.autocast('cuda'):
                    logits, values = self.model(model_input)

                # Mask illegal moves; rows with no legal move get flat logits.
                masks_np, legal_moves_list = get_legal_masks(envs, active)
                masks_t = torch.as_tensor(masks_np, dtype=torch.bool, device=self.device)
                logits = logits.float()
                logits = torch.where(masks_t, logits,
                                     torch.tensor(-60000.0, device=self.device))
                no_legal = ~masks_t.any(dim=-1, keepdim=True)
                logits.masked_fill_(no_legal, 0.0)

                probs = F.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)
                actions = dist.sample()

                actions_buf[t] = actions.to(torch.int16)
                logprobs_buf[t] = dist.log_prob(actions)
                values_buf[t] = values.squeeze(-1)

                actions_cpu = actions.cpu().numpy()

                for b in range(B):
                    if not active[b]: continue

                    move_uci = ACTION_MAPPER.idx_to_move[actions_cpu[b]]
                    move = chess.Move.from_uci(move_uci)
                    if move not in legal_moves_list[b]:
                        move = random.choice(legal_moves_list[b])

                    board = envs[b]
                    mover_is_white = (board.turn == chess.WHITE)
                    sign = 1.0 if mover_is_white else -1.0

                    # All shaping rewards are signed from White's perspective.
                    r = -0.003 * sign  # per-ply time pressure

                    if board.is_capture(move):
                        if board.is_en_passant(move):
                            cap_val = 1.0  # en passant always captures a pawn
                        else:
                            cp = board.piece_at(move.to_square)
                            cap_val = PIECE_VAL.get(cp.piece_type, 0.0) if cp else 0.0
                        r += sign * (cap_val / 9.0) * 0.3
                        total_captures += 1

                    if move.promotion in (chess.QUEEN, chess.ROOK):
                        r += sign * 0.15

                    board.push(move)
                    game_lengths[b] += 1

                    if board.is_check():
                        r += sign * 0.3
                        total_checks += 1

                    if board.is_game_over():
                        if board.is_checkmate():
                            # Faster mates earn a larger bonus: 1.0 up to 1.5.
                            speed_bonus = 0.5 * math.exp(-game_lengths[b] / 20.0)
                            r += sign * (1.0 + speed_bonus)
                            if mover_is_white: white_wins += 1
                            else: black_wins += 1
                        else:
                            # Any non-mate termination is penalized as a draw.
                            r -= 0.5
                            draws_count += 1
                        dones_buf[t, b] = True
                        active[b] = False

                    rewards_buf[t, b] = r

            # Discounted returns, backward pass: G_t = r_t + gamma * G_{t+1},
            # with the bootstrap zeroed at terminal steps.
            returns = torch.zeros(B, dtype=torch.float32, device=self.device)
            returns_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
            not_done_f = (~dones_buf).float()
            for step in reversed(range(max_steps)):
                returns = rewards_buf[step] + gamma * returns * not_done_f[step]
                returns_buf[step] = returns

            # GRPO-style advantages: center and scale (return - value) within
            # each group of G envs, counting only steps where the env was live.
            adv_raw = returns_buf - values_buf
            active_f = active_buf.float()

            adv_3d = adv_raw.view(max_steps, num_groups, G)
            act_3d = active_f.view(max_steps, num_groups, G)

            g_count = act_3d.sum(dim=[0, 2]).clamp(min=1.0)
            g_mean = (adv_3d * act_3d).sum(dim=[0, 2]) / g_count
            g_sq_diff = ((adv_3d - g_mean.view(1, num_groups, 1)) ** 2
                         * act_3d).sum(dim=[0, 2])
            g_std = (g_sq_diff / g_count).sqrt().clamp(min=1e-8)
            adv_3d = (adv_3d - g_mean.view(1, num_groups, 1)) / \
                     g_std.view(1, num_groups, 1)
            adv_norm = adv_3d.view(max_steps, B)

            # Flatten the rollout and keep only steps where the env was active.
            valid_mask = active_buf.view(-1)
            flat_states = (states_buf.view(-1, 14, 8, 8)[valid_mask]
                           .to(torch.float32, memory_format=torch.channels_last)
                           .div_(127.0))
            flat_actions = actions_buf.view(-1)[valid_mask].to(torch.int64)
            flat_old_lp = logprobs_buf.view(-1)[valid_mask]
            flat_returns = returns_buf.view(-1)[valid_mask]
            flat_advantages = adv_norm.view(-1)[valid_mask]

            dataset_size = flat_states.size(0)
            if dataset_size < 100:  # degenerate rollout; skip the update
                continue
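
            # Update objective, as implemented below: the clipped PPO surrogate
            # on group-normalized advantages A,
            #   L = -E[min(rho * A, clip(rho, 1 - eps, 1 + eps) * A)]
            #       + value_coef * MSE(V, G) - entropy_coef * H(pi),
            # with rho = pi_new(a|s) / pi_old(a|s) and eps = clip_epsilon.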

            self.model.train()
            total_p_loss = total_v_loss = 0.0
            num_updates = 0
            mb_size = CONFIG["mini_batch_size"]

            for _ in range(CONFIG["ppo_epochs"]):
                perm = torch.randperm(dataset_size, device=self.device)
                for start in range(0, dataset_size, mb_size):
                    mb = perm[start: start + mb_size]
                    with torch.amp.autocast('cuda'):
                        new_logits, new_vals = self.model(flat_states[mb])
                        new_dist = torch.distributions.Categorical(logits=new_logits)
                        new_lp = new_dist.log_prob(flat_actions[mb])
                        ratio = torch.exp(new_lp - flat_old_lp[mb])
                        adv = flat_advantages[mb]
                        surr1 = ratio * adv
                        surr2 = torch.clamp(
                            ratio,
                            1.0 - CONFIG["clip_epsilon"],
                            1.0 + CONFIG["clip_epsilon"],
                        ) * adv
                        p_loss = -torch.min(surr1, surr2).mean()
                        v_loss = F.mse_loss(new_vals.squeeze(-1), flat_returns[mb])
                        entropy = new_dist.entropy().mean()
                        loss = (p_loss
                                + CONFIG["value_coef"] * v_loss
                                - CONFIG["entropy_coef"] * entropy)

                    self.optimizer.zero_grad(set_to_none=True)
                    self.scaler.scale(loss).backward()
                    # Unscale before clipping so the max-norm applies to the
                    # true gradients, then step through the scaler.
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    total_p_loss += p_loss.item()
                    total_v_loss += v_loss.item()
                    num_updates += 1

            # Iteration metrics. win_rate counts White checkmates among
            # finished games, matching the "White perspective" dashboard panel.
            done_count = white_wins + black_wins + draws_count
            win_rate = white_wins / max(done_count, 1)
            draw_rate = draws_count / max(done_count, 1)
            active_steps = int(active_buf.sum().item())
            check_rate = total_checks / max(active_steps, 1)
            capture_rate = total_captures / max(active_steps, 1)
            avg_game_len = float(game_lengths.mean())
            fps = dataset_size / max(time.time() - iter_start, 1e-3)

            if (iteration + 1) % CONFIG["log_interval"] == 0:
                vram_alloc = (torch.cuda.memory_allocated() / 1e9
                              if torch.cuda.is_available() else 0.0)
                vram_res = (torch.cuda.memory_reserved() / 1e9
                            if torch.cuda.is_available() else 0.0)
                print(
                    f"[{iteration+1:05d}] "
                    f"P:{total_p_loss/max(1,num_updates):.4f} "
                    f"V:{total_v_loss/max(1,num_updates):.4f} | "
                    f"W:{win_rate:.3f} D:{draw_rate:.3f} "
                    f"Chk:{check_rate:.4f} Cap:{capture_rate:.4f} "
                    f"Len:{avg_game_len:.1f} | "
                    f"ELO:{self.elo_tracker.elo:.0f} | "
                    f"FPS:{fps:.0f} | "
                    f"VRAM:{vram_alloc:.2f}/{vram_res:.2f}GB"
                )
                with open(self.log_file, "a", newline="") as f:
                    csv.writer(f).writerow([
                        iteration + 1,
                        total_p_loss / max(1, num_updates),
                        total_v_loss / max(1, num_updates),
                        flat_returns.mean().item(),
                        fps, win_rate, draw_rate,
                        check_rate, capture_rate, avg_game_len,
                    ])

            if win_rate > self.best_win_rate:
                self.best_win_rate = win_rate
                self.save_checkpoint(iteration + 1, is_best=True)

            if (iteration + 1) % CONFIG["save_interval"] == 0:
                self.save_checkpoint(iteration + 1)
                self.plot_metrics()

            if (iteration + 1) % CONFIG["elo_eval_interval"] == 0:
                elo_before = self.elo_tracker.elo
                ew, ed, el = self.evaluate_elo(CONFIG["elo_eval_games"])
                print(
                    f" [ELO eval] {elo_before:.0f} → {self.elo_tracker.elo:.0f} | "
                    f"W:{ew} D:{ed} L:{el} vs random({RANDOM_BASELINE_ELO})"
                )
                with open(self.elo_log_file, "a", newline="") as f:
                    csv.writer(f).writerow(
                        [iteration + 1, self.elo_tracker.elo, ew, ed, el])
                self.plot_metrics()

            torch.cuda.empty_cache()

    def plot_metrics(self) -> None:
        if not os.path.exists(self.log_file): return
        df = pd.read_csv(self.log_file)
        if len(df) < 2: return

        elo_df = None
        if os.path.exists(self.elo_log_file):
            elo_df = pd.read_csv(self.elo_log_file)

        fig, axs = plt.subplots(3, 2, figsize=(14, 12))
        fig.suptitle("Aggressive GRPO Chess Agent – Training Dashboard", fontsize=14)

        axs[0, 0].plot(df['iteration'], df['p_loss'], color='steelblue', linewidth=1.2)
        axs[0, 0].set_title('Policy Loss'); axs[0, 0].set_xlabel('Iteration')

        axs[0, 1].plot(df['iteration'], df['v_loss'], color='tomato', linewidth=1.2)
        axs[0, 1].set_title('Value Loss'); axs[0, 1].set_xlabel('Iteration')

        axs[1, 0].plot(df['iteration'], df['win_rate'], label='Win', color='green')
        axs[1, 0].plot(df['iteration'], df['draw_rate'], label='Draw', color='orange')
        axs[1, 0].set_title('Outcomes (White perspective)')
        axs[1, 0].legend(); axs[1, 0].set_xlabel('Iteration')

        axs[1, 1].plot(df['iteration'], df['check_rate'], label='Check/step', color='purple')
        axs[1, 1].plot(df['iteration'], df['capture_rate'], label='Capture/step', color='darkorange')
        axs[1, 1].set_title('Attack Metrics (↑ = more aggressive)')
        axs[1, 1].legend(); axs[1, 1].set_xlabel('Iteration')

        if elo_df is not None and len(elo_df) > 0:
            axs[2, 0].plot(elo_df['iteration'], elo_df['elo'],
                           color='gold', linewidth=2.0, label='Agent ELO')
            axs[2, 0].axhline(RANDOM_BASELINE_ELO, linestyle='--',
                              color='gray', alpha=0.8, label=f'Random ({RANDOM_BASELINE_ELO})')
            axs[2, 0].axhline(1200, linestyle=':', color='lightblue',
                              alpha=0.6, label='Start (1200)')
            axs[2, 0].fill_between(elo_df['iteration'], RANDOM_BASELINE_ELO,
                                   elo_df['elo'], alpha=0.15, color='gold')
            axs[2, 0].set_title('ELO Rating vs Random Baseline')
            axs[2, 0].legend(); axs[2, 0].set_xlabel('Iteration')
        else:
            axs[2, 0].text(0.5, 0.5, f'ELO eval every {CONFIG["elo_eval_interval"]} iters',
                           ha='center', va='center', transform=axs[2, 0].transAxes,
                           color='gray', fontsize=11)
            axs[2, 0].set_title('ELO Rating (pending)')

        axs[2, 1].plot(df['iteration'], df['avg_game_len'], color='teal', linewidth=1.2)
        axs[2, 1].set_title('Avg Game Length (↓ = faster checkmates)')
        axs[2, 1].set_xlabel('Iteration')

        for ax in axs.flat:
            ax.grid(True, alpha=0.25)

        plt.tight_layout()
        out = os.path.join(CONFIG["checkpoint_dir"], "training_performance.png")
        plt.savefig(out, dpi=100, bbox_inches='tight')
        plt.close(fig)
        print(f" [Plot] saved → {out}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aggressive GRPO Chess Agent (T4/Colab)")
    parser.add_argument("--iterations", type=int, default=10000,
                        help="Total training iterations")
    parser.add_argument("--test-batch", action="store_true",
                        help="Run 2 iterations as a smoke test")
    args, _ = parser.parse_known_args()

    torch.manual_seed(CONFIG["seed"])
    np.random.seed(CONFIG["seed"])
    random.seed(CONFIG["seed"])

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        print(f"GPU: {props.name} | VRAM: {props.total_memory/1e9:.1f}GB | "
              f"SM: {props.multi_processor_count} | "
              f"Compute: {props.major}.{props.minor}")

    trainer = GRPOTrainer()
    trainer.train(2 if args.test_batch else args.iterations)