algorembrant committed on
Commit 73d8b96 · verified · 1 Parent(s): 634251f

Upload model_aggressive.py

Files changed (1)
  1. model_aggressive.py +769 -0
model_aggressive.py ADDED
@@ -0,0 +1,769 @@
# NOTE FOR COLAB USERS: Run in a separate cell first:
# !pip -q install chess numpy torch matplotlib pandas

"""
Aggressive GRPO Chess Agent — T4/Colab Optimized
"""

import os, sys, csv, time, math, shutil, argparse, random
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

try:
    import chess
except ImportError:
    os.system("pip install -q chess")
    import chess

import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Hardware flags ─────────────────────────────────────────────────────────────
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision('high')

# ── Constants ──────────────────────────────────────────────────────────────────
PIECE_VAL = {
    chess.PAWN: 1.0, chess.KNIGHT: 3.0, chess.BISHOP: 3.2,
    chess.ROOK: 5.0, chess.QUEEN: 9.0, chess.KING: 0.0,
}
RANDOM_BASELINE_ELO = 800  # estimated Elo of a uniform-random player

CONFIG = {
    "num_envs": 256,
    "grpo_group_size": 8,      # G envs per group, all starting from the same opening position
    "ppo_epochs": 3,
    "mini_batch_size": 4096,
    "learning_rate": 2e-4,
    "weight_decay": 1e-4,
    "gamma": 0.98,             # lower → discount the future more → prefer fast wins
    "clip_epsilon": 0.15,
    "entropy_coef": 0.02,      # low → exploit aggressive lines
    "value_coef": 0.5,
    "max_steps": 100,
    "opening_max_moves": 10,   # randomize openings for GRPO diversity
    "checkpoint_dir": "./checkpoints",
    "save_interval": 50,
    "log_interval": 1,
    "elo_eval_interval": 100,  # evaluate Elo every N iterations
    "elo_eval_games": 32,
    "max_runtime_hours": 4.5,  # auto-save + download before Colab kills the session
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "seed": 42,
}
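
# With the settings above, each iteration runs 256 envs as 256 / 8 = 32 GRPO
# groups of G = 8; every group shares one randomized opening (see
# get_opening_position below), so group members differ only in sampled actions.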

# ── Action Space ───────────────────────────────────────────────────────────────
class ActionMapper:
    __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']

    def __init__(self):
        self.move_to_idx: dict[str, int] = {}
        self.idx_to_move: list[str] = []
        idx = 0
        for f in range(64):
            for t in range(64):
                if f == t: continue
                uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t]
                self.move_to_idx[uci] = idx
                self.idx_to_move.append(uci)
                idx += 1
                # Promotion variants for moves leaving ranks 2/7; targets that
                # cannot actually promote are included but are never legal, so
                # the legality mask keeps them at probability ~0.
                if chess.square_rank(f) in (1, 6) and \
                        abs(chess.square_file(f) - chess.square_file(t)) <= 1:
                    for promo in "nbrq":
                        puci = uci + promo
                        self.move_to_idx[puci] = idx
                        self.idx_to_move.append(puci)
                        idx += 1
        self.num_actions = idx

ACTION_MAPPER = ActionMapper()
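
# Illustrative sanity check (not part of the original training flow):
#   idx = ACTION_MAPPER.move_to_idx["e2e4"]
#   ACTION_MAPPER.idx_to_move[idx]      # -> "e2e4"
#   ACTION_MAPPER.move_to_idx["e7e8q"]  # promotion variants get their own index
# The base from/to pairs alone contribute 64 * 63 = 4032 actions; promotion
# suffixes account for the rest of num_actions.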

# ── Board Encoding ─────────────────────────────────────────────────────────────
def populate_states_fast(envs: list, active_mask: np.ndarray,
                         bbs_np: np.ndarray, meta_np: np.ndarray) -> None:
    """Fill bbs_np [B,12] uint64 and meta_np [B,3] float32 for active envs."""
    for b in range(len(envs)):
        if not active_mask[b]: continue
        env = envs[b]
        w = env.occupied_co[chess.WHITE]
        bc = env.occupied_co[chess.BLACK]
        # NOTE: bbs_np must be uint64 — python-chess bitboards use the full
        # 64-bit range (a piece on h8 sets bit 63), which overflows int64.
        bbs_np[b, 0] = env.pawns & w;    bbs_np[b, 1] = env.knights & w
        bbs_np[b, 2] = env.bishops & w;  bbs_np[b, 3] = env.rooks & w
        bbs_np[b, 4] = env.queens & w;   bbs_np[b, 5] = env.kings & w
        bbs_np[b, 6] = env.pawns & bc;   bbs_np[b, 7] = env.knights & bc
        bbs_np[b, 8] = env.bishops & bc; bbs_np[b, 9] = env.rooks & bc
        bbs_np[b, 10] = env.queens & bc; bbs_np[b, 11] = env.kings & bc
        meta_np[b, 0] = 1.0 if env.turn else -1.0
        # Castling rights: python-chess stores them as a bitboard of rook
        # squares, so pack the four rights into 4 bits before normalizing.
        cr = env.castling_rights
        cr4 = ((1 if cr & chess.BB_H1 else 0) | (2 if cr & chess.BB_A1 else 0) |
               (4 if cr & chess.BB_H8 else 0) | (8 if cr & chess.BB_A8 else 0))
        meta_np[b, 1] = cr4 / 15.0  # [0, 1]
        meta_np[b, 2] = 1.0 if env.ep_square is not None else 0.0

def get_legal_masks(envs: list, active_mask: np.ndarray):
    masks = np.zeros((len(envs), ACTION_MAPPER.num_actions), dtype=np.bool_)
    moves_list = [None] * len(envs)
    for b in range(len(envs)):
        if not active_mask[b]: continue
        legal = list(envs[b].legal_moves)
        moves_list[b] = legal
        for m in legal:
            masks[b, ACTION_MAPPER.move_to_idx[m.uci()]] = True
    return masks, moves_list

# ── Neural Network ─────────────────────────────────────────────────────────────
class ChessNet(nn.Module):
    def __init__(self, res_blocks: int = 8, channels: int = 128):
        super().__init__()
        self.conv_in = nn.Conv2d(14, channels, 3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
            ) for _ in range(res_blocks)
        ])
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Flatten(),
            nn.Linear(32 * 64, ACTION_MAPPER.num_actions),
        )
        # No Tanh — shaped rewards exceed [-1, 1]; the value output stays an
        # unbounded linear layer.
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 32, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Flatten(),
            nn.Linear(32 * 64, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        x = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
        for blk in self.res_blocks:
            x = F.relu(x + blk(x), inplace=True)
        return self.policy_head(x), self.value_head(x)
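
# Shape walkthrough (reading off the layers above): input [B, 14, 8, 8]
# -> conv_in -> [B, 128, 8, 8] -> 8 residual blocks (shape-preserving)
# -> policy head [B, num_actions] and value head [B, 1]. Planes 0-11 hold the
# twelve piece bitboards, plane 12 the side to move, and plane 13 the castling
# encoding with an en-passant flag tucked into cell [13, 0, 1].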

# ── ELO Tracker ────────────────────────────────────────────────────────────────
class ELOTracker:
    def __init__(self, initial_elo: float = 1200.0, K: float = 32.0):
        self.elo = initial_elo
        self.K = K

    def expected(self, opp_elo: float) -> float:
        return 1.0 / (1.0 + 10.0 ** ((opp_elo - self.elo) / 400.0))

    def update(self, score: float, opp_elo: float) -> None:
        self.elo += self.K * (score - self.expected(opp_elo))
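
# Worked example of the update rule (standard Elo, K = 32): at a rating of
# 1200 against an 800-rated opponent, expected = 1 / (1 + 10 ** -1) ≈ 0.909,
# so a win adds 32 * (1 - 0.909) ≈ +2.9 while a loss costs ≈ -29.1. Beating
# the random baseline therefore moves the rating only slowly once well above it.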

# ── Opening Position Generator ─────────────────────────────────────────────────
def get_opening_position(max_moves: int = 10) -> chess.Board:
    """Play 0..max_moves random half-moves from the start for GRPO diversity."""
    board = chess.Board()
    for _ in range(random.randint(0, max_moves)):
        if board.is_game_over(): break
        board.push(random.choice(list(board.legal_moves)))
    return chess.Board(board.fen())  # detached copy
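
# Note on the detached copy: rebuilding from FEN drops the move stack. The
# halfmove clock survives (it is part of the FEN string), but threefold-
# repetition history does not, so every env starts its repetition count fresh.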

# ── Auto-download ──────────────────────────────────────────────────────────────
def auto_download(checkpoint_dir: str) -> None:
    """Sync to Google Drive if mounted, else trigger browser downloads."""
    try:
        from google.colab import files as _cf
        drive_dst = '/content/drive/MyDrive/chess_agent'
        if os.path.exists('/content/drive/MyDrive'):
            os.makedirs(drive_dst, exist_ok=True)
            shutil.copytree(checkpoint_dir, drive_dst, dirs_exist_ok=True)
            print(f"[AutoSave] Synced → {drive_dst}")
        else:
            for fname in ['best.pt', 'latest.pt', 'training_log.csv',
                          'elo_log.csv', 'training_performance.png']:
                fpath = os.path.join(checkpoint_dir, fname)
                if os.path.exists(fpath):
                    _cf.download(fpath)
                    print(f"[AutoSave] Downloaded {fname}")
    except Exception as e:
        print(f"[AutoSave] {e}")

# ── GRPO Trainer ───────────────────────────────────────────────────────────────
class GRPOTrainer:

    def __init__(self):
        self.device = CONFIG["device"]

        _model = ChessNet(res_blocks=8, channels=128)
        _model = _model.to(self.device).to(memory_format=torch.channels_last)
        try:
            print("Compiling model (reduce-overhead)…")
            self.model = torch.compile(_model, mode="reduce-overhead")
        except Exception:
            self.model = _model

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=CONFIG["learning_rate"],
            weight_decay=CONFIG["weight_decay"],
            fused=torch.cuda.is_available(),
        )
        self.scaler = torch.amp.GradScaler('cuda')
        self.start_iter = 0
        self.best_win_rate = 0.0
        self.elo_tracker = ELOTracker()

        # Shared shift tensor for bit-unpacking (avoids repeated allocation)
        self.shifts = torch.arange(64, dtype=torch.int64,
                                   device=self.device).view(1, 1, 64)
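
        # How the unpack works, on one board (illustrative): python-chess
        # stores each piece set as a 64-bit integer, one bit per square
        # (a1 = bit 0 ... h8 = bit 63). Shifting right by k and masking with 1
        # reads square k, e.g. at the start position
        #   wp = chess.Board().pawns & chess.Board().occupied_co[chess.WHITE]
        #   (wp >> chess.E2) & 1   # -> 1, white pawn on e2
        # The tensor version broadcasts all 64 shifts over all boards at once.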

        os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)
        self.log_file = os.path.join(CONFIG["checkpoint_dir"], "training_log.csv")
        self.elo_log_file = os.path.join(CONFIG["checkpoint_dir"], "elo_log.csv")

        if not os.path.exists(self.log_file):
            with open(self.log_file, "w", newline="") as f:
                csv.writer(f).writerow([
                    # "ret_mean" (was mislabeled "v_mean"): the logged value is
                    # the mean discounted return, not a value-head statistic.
                    "iteration", "p_loss", "v_loss", "ret_mean", "fps",
                    "win_rate", "draw_rate", "check_rate", "capture_rate", "avg_game_len",
                ])
        if not os.path.exists(self.elo_log_file):
            with open(self.elo_log_file, "w", newline="") as f:
                csv.writer(f).writerow(
                    ["iteration", "elo", "eval_wins", "eval_draws", "eval_losses"])

        self._init_checkpointing()

    # ── Checkpointing ──────────────────────────────────────────────────────────
    def _init_checkpointing(self) -> None:
        latest = os.path.join(CONFIG["checkpoint_dir"], "latest.pt")
        if not os.path.exists(latest):
            return
        try:
            ckpt = torch.load(latest, map_location=self.device, weights_only=False)
            sd = ckpt['model_state_dict']
            # Handle compiled ('_orig_mod.' prefix) vs uncompiled state dicts
            loaded = False
            for attempt in [
                sd,
                {k.replace('_orig_mod.', ''): v for k, v in sd.items()},
                {'_orig_mod.' + k: v for k, v in sd.items()},
            ]:
                try:
                    self.model.load_state_dict(attempt); loaded = True; break
                except RuntimeError:
                    continue
            if not loaded:
                raise RuntimeError("All state dict key variants failed.")
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.scaler.load_state_dict(ckpt['scaler_state_dict'])
            self.start_iter = ckpt.get('iteration', 0) + 1
            self.elo_tracker.elo = ckpt.get('elo', 1200.0)
            self.best_win_rate = ckpt.get('best_win_rate', 0.0)
            print(f"Resumed from iter {self.start_iter} | "
                  f"ELO {self.elo_tracker.elo:.0f} | best_win {self.best_win_rate:.3f}")
        except Exception as e:
            print(f"Checkpoint load failed ({e}). Starting fresh.")

    def save_checkpoint(self, iteration: int, is_best: bool = False) -> None:
        ckpt = {
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'elo': self.elo_tracker.elo,
            'best_win_rate': self.best_win_rate,
            'config': CONFIG,
        }
        cdir = CONFIG["checkpoint_dir"]
        path = os.path.join(cdir, f"iter_{iteration:04d}.pt")
        # Atomic write: write to .tmp, then os.replace (single syscall, crash-safe)
        torch.save(ckpt, path + ".tmp"); os.replace(path + ".tmp", path)
        latest = os.path.join(cdir, "latest.pt")
        shutil.copy2(path, latest + ".tmp"); os.replace(latest + ".tmp", latest)
        if is_best:
            best = os.path.join(cdir, "best.pt")
            shutil.copy2(path, best + ".tmp"); os.replace(best + ".tmp", best)

    # ── ELO Evaluation (batched, greedy) ───────────────────────────────────────
    def _elo_game_done(self, board: chess.Board, idx: int, agent_color,
                       scores: np.ndarray, active: np.ndarray) -> None:
        if board.is_game_over():
            res = board.result()
            if (res == "1-0" and agent_color == chess.WHITE) or \
                    (res == "0-1" and agent_color == chess.BLACK):
                scores[idx] = 1.0
            elif res == "1/2-1/2":
                scores[idx] = 0.5
            else:
                scores[idx] = 0.0
            active[idx] = False

    def evaluate_elo(self, n_games: int = 32, max_ply: int = 200) -> tuple:
        """
        Play n_games vs a random opponent (agent moves batched on GPU).
        Half the games as White, half as Black.
        Returns (wins, draws, losses) from the agent's perspective.
        """
        self.model.eval()
        boards = [chess.Board() for _ in range(n_games)]
        agent_colors = [chess.WHITE if i % 2 == 0 else chess.BLACK
                        for i in range(n_games)]
        scores = np.full(n_games, 0.5, dtype=np.float32)  # default: draw
        active = np.ones(n_games, dtype=bool)
        bbs_sub = np.zeros((n_games, 12), dtype=np.uint64)  # uint64: bit 63 must fit
        meta_sub = np.zeros((n_games, 3), dtype=np.float32)

        for _ in range(max_ply):
            if not active.any(): break

            # Random moves (opponent turns) — CPU
            for i in [i for i in range(n_games)
                      if active[i] and boards[i].turn != agent_colors[i]]:
                legal = list(boards[i].legal_moves)
                if legal: boards[i].push(random.choice(legal))
                self._elo_game_done(boards[i], i, agent_colors[i], scores, active)

            # Agent moves (batched GPU)
            ag_idx = [i for i in range(n_games)
                      if active[i] and boards[i].turn == agent_colors[i]]
            if not ag_idx:
                continue

            n = len(ag_idx)
            sub = [boards[i] for i in ag_idx]
            act_sub = np.ones(n, dtype=bool)
            populate_states_fast(sub, act_sub, bbs_sub[:n], meta_sub[:n])

            # Reinterpret the uint64 bit patterns as int64 for torch; the
            # >> & 1 unpack reads single bits, so sign extension is harmless.
            bbs_t = torch.tensor(bbs_sub[:n].view(np.int64),
                                 dtype=torch.int64, device=self.device)
            unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).float().view(n, 12, 8, 8)
            state = torch.zeros(n, 14, 8, 8, device=self.device, dtype=torch.float32)
            state[:, :12] = unpacked
            state[:, 12] = torch.tensor(meta_sub[:n, 0], device=self.device).view(n, 1, 1).expand(n, 8, 8)
            state[:, 13] = torch.tensor(meta_sub[:n, 1], device=self.device).view(n, 1, 1).expand(n, 8, 8)
            for lj in range(n):
                if meta_sub[lj, 2]:
                    state[lj, 13, 0, 1] = float(meta_sub[lj, 2])

            with torch.no_grad(), torch.amp.autocast('cuda'):
                logits, _ = self.model(state.to(memory_format=torch.channels_last))
                logits = logits.float()

            masks_np, legal_lists = get_legal_masks(sub, act_sub)
            masks_t = torch.tensor(masks_np, dtype=torch.bool, device=self.device)
            logits = torch.where(masks_t, logits,
                                 torch.tensor(-60000.0, device=self.device))
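
            # Why -60000 rather than -inf: under float16 autocast the most
            # negative representable value is about -65504, so -60000 survives
            # a cast to half precision while still driving illegal moves to
            # ~0 probability in softmax; -inf could turn fully-masked rows
            # into NaNs.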
            best_acts = logits.argmax(dim=-1).cpu().numpy()  # greedy for evaluation

            for lj, gi in enumerate(ag_idx):
                if not active[gi]: continue
                move_uci = ACTION_MAPPER.idx_to_move[best_acts[lj]]
                move = chess.Move.from_uci(move_uci)
                legal = legal_lists[lj] or list(boards[gi].legal_moves)
                if not legal:
                    active[gi] = False; continue
                if move not in legal:
                    move = random.choice(legal)
                boards[gi].push(move)
                self._elo_game_done(boards[gi], gi, agent_colors[gi], scores, active)

        wins = int((scores == 1.0).sum())
        draws = int((scores == 0.5).sum())
        losses = int((scores == 0.0).sum())
        for s in scores:
            self.elo_tracker.update(float(s), RANDOM_BASELINE_ELO)
        return wins, draws, losses

    # ── Main Training Loop ─────────────────────────────────────────────────────
    def train(self, num_iterations: int) -> None:
        B = CONFIG["num_envs"]
        max_steps = CONFIG["max_steps"]
        G = CONFIG["grpo_group_size"]
        num_groups = B // G
        gamma = CONFIG["gamma"]
        t_start = time.time()
        max_rt = CONFIG["max_runtime_hours"] * 3600.0

        # ── Preallocate GPU buffers (int8/bool minimizes VRAM footprint) ───────
        states_buf = torch.zeros((max_steps, B, 14, 8, 8), dtype=torch.int8, device=self.device)
        actions_buf = torch.zeros((max_steps, B), dtype=torch.int16, device=self.device)
        logprobs_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        values_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        rewards_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
        dones_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)
        active_buf = torch.zeros((max_steps, B), dtype=torch.bool, device=self.device)

        bbs_np = np.zeros((B, 12), dtype=np.uint64)  # uint64: full bitboard range
        meta_np = np.zeros((B, 3), dtype=np.float32)
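
        # Rough footprint of the big rollout buffer (arithmetic, not measured):
        # states_buf holds 100 * 256 * 14 * 8 * 8 int8 values ≈ 22.9 MB, versus
        # ≈ 91.8 MB if it were float32 — the int8 packing is what keeps rollout
        # storage cheap relative to the model itself.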

        vram_gb = (torch.cuda.get_device_properties(0).total_memory / 1e9
                   if torch.cuda.is_available() else 0.0)
        print(f"\n🚀 Aggressive GRPO Chess Agent")
        print(f"   Envs:{B} | Groups:{num_groups}×G:{G} | Device:{self.device.upper()} | "
              f"VRAM:{vram_gb:.1f}GB")
        print(f"   Reward: capture(0-0.3)+check(0.3)+checkmate_speed(1.0-1.5)"
              f"+draw_penalty(-0.5)+time(-0.003/step)")
        print(f"   gamma:{gamma} | entropy:{CONFIG['entropy_coef']} | "
              f"lr:{CONFIG['learning_rate']}")

        for iteration in range(self.start_iter, num_iterations):

            # ── Runtime guard ──────────────────────────────────────────────────
            elapsed = time.time() - t_start
            if elapsed > max_rt:
                print(f"\n⏱ {elapsed/3600:.2f}h reached. Saving & downloading…")
                self.save_checkpoint(iteration)
                self.plot_metrics()
                auto_download(CONFIG["checkpoint_dir"])
                break

            iter_start = time.time()

            # Zero buffers in place (no reallocation)
            states_buf.zero_(); actions_buf.zero_(); logprobs_buf.zero_()
            values_buf.zero_(); rewards_buf.zero_()
            dones_buf.fill_(False); active_buf.fill_(False)

            # ── GRPO: each group of G envs shares one opening position ─────────
            fens = [get_opening_position(CONFIG["opening_max_moves"]).fen()
                    for _ in range(num_groups)]
            envs: list[chess.Board] = []
            for gi in range(num_groups):
                for _ in range(G):
                    envs.append(chess.Board(fens[gi]))

            active = np.ones(B, dtype=bool)
            game_lengths = np.zeros(B, dtype=np.int32)

            # Per-iteration attack metrics
            white_wins = black_wins = draws_count = 0
            total_checks = total_captures = 0

            # ── PHASE 1: ROLLOUT ───────────────────────────────────────────────
            for t in range(max_steps):
                if not active.any(): break

                populate_states_fast(envs, active, bbs_np, meta_np)

                # Bit-unpack bitboards → int8 state tensor (no float copy);
                # uint64 is reinterpreted as int64, which >> & 1 reads correctly.
                bbs_t = torch.as_tensor(bbs_np.view(np.int64),
                                        dtype=torch.int64, device=self.device)
                unpacked = ((bbs_t.unsqueeze(-1) >> self.shifts) & 1).to(torch.int8)
                meta_t = torch.as_tensor(meta_np, dtype=torch.float32, device=self.device)

                # Pack into the int8 buffer (scale float meta to [-127, 127])
                states_buf[t, :, :12, :, :] = unpacked.view(B, 12, 8, 8)
                states_buf[t, :, 12, :, :] = (meta_t[:, 0] * 127).clamp(-127, 127) \
                    .to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
                states_buf[t, :, 13, :, :] = (meta_t[:, 1] * 127).clamp(0, 127) \
                    .to(torch.int8).view(B, 1, 1).expand(B, 8, 8)
                states_buf[t, :, 13, 0, 1] = (meta_t[:, 2] * 127).clamp(0, 127).to(torch.int8)
                active_buf[t] = torch.as_tensor(active, dtype=torch.bool, device=self.device)

                # Normalize int8 → float32 for the forward pass
                model_input = states_buf[t].to(
                    dtype=torch.float32, memory_format=torch.channels_last) / 127.0

                self.model.eval()
                with torch.no_grad(), torch.amp.autocast('cuda'):
                    logits, values = self.model(model_input)

                masks_np, legal_moves_list = get_legal_masks(envs, active)
                masks_t = torch.as_tensor(masks_np, dtype=torch.bool, device=self.device)
                logits = logits.float()
                logits = torch.where(masks_t, logits,
                                     torch.tensor(-60000.0, device=self.device))
                # Rows with no legal move (finished envs) get flat logits so
                # Categorical stays valid; their sampled actions are ignored.
                no_legal = ~masks_t.any(dim=-1, keepdim=True)
                logits.masked_fill_(no_legal, 0.0)

                probs = F.softmax(logits, dim=-1)
                dist = torch.distributions.Categorical(probs)
                actions = dist.sample()

                actions_buf[t] = actions.to(torch.int16)
                logprobs_buf[t] = dist.log_prob(actions)
                values_buf[t] = values.squeeze(-1)

                actions_cpu = actions.cpu().numpy()

                for b in range(B):
                    if not active[b]: continue

                    move_uci = ACTION_MAPPER.idx_to_move[actions_cpu[b]]
                    move = chess.Move.from_uci(move_uci)
                    if move not in legal_moves_list[b]:
                        move = random.choice(legal_moves_list[b])

                    board = envs[b]
                    mover_is_white = (board.turn == chess.WHITE)
                    sign = 1.0 if mover_is_white else -1.0

                    # ── Reward: pre-push components ────────────────────────────
                    r = -0.003 * sign  # time penalty (per mover, white-perspective)

                    if board.is_capture(move):
                        if board.is_en_passant(move):
                            cap_val = 1.0
                        else:
                            cp = board.piece_at(move.to_square)
                            cap_val = PIECE_VAL.get(cp.piece_type, 0.0) if cp else 0.0
                        r += sign * (cap_val / 9.0) * 0.3  # [0, 0.3]
                        total_captures += 1
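
                    # Worked example of the capture term: taking a queen is
                    # worth sign * (9.0 / 9.0) * 0.3 = ±0.3, a pawn only
                    # sign * (1.0 / 9.0) * 0.3 ≈ ±0.033; en passant is scored
                    # as a pawn capture (cap_val = 1.0).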

                    if move.promotion in (chess.QUEEN, chess.ROOK):
                        r += sign * 0.15  # aggressive promotion

                    board.push(move)
                    game_lengths[b] += 1

                    # ── Reward: post-push components ───────────────────────────
                    if board.is_check():
                        r += sign * 0.3  # gave check
                        total_checks += 1

                    if board.is_game_over():
                        if board.is_checkmate():
                            # Mover delivered checkmate
                            speed_bonus = 0.5 * math.exp(-game_lengths[b] / 20.0)
                            r += sign * (1.0 + speed_bonus)  # ~1.0-1.5
                            if mover_is_white: white_wins += 1
                            else: black_wins += 1
                        else:
                            # Draw (stalemate / 50-move / repetition / insufficient material)
                            r -= 0.5  # flat penalty from White's perspective — attack to WIN
                            draws_count += 1
                        dones_buf[t, b] = True
                        active[b] = False

                    rewards_buf[t, b] = r
                # end per-env loop
            # end rollout

            # ── PHASE 2: VECTORIZED RETURNS ────────────────────────────────────
            returns = torch.zeros(B, dtype=torch.float32, device=self.device)
            returns_buf = torch.zeros((max_steps, B), dtype=torch.float32, device=self.device)
            not_done_f = (~dones_buf).float()
            for step in reversed(range(max_steps)):
                returns = rewards_buf[step] + gamma * returns * not_done_f[step]
                returns_buf[step] = returns
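
            # The backward sweep computes, per env, G_t = r_t + gamma * G_{t+1},
            # zeroing the bootstrap across episode boundaries via not_done.
            # E.g. with gamma = 0.98 and rewards [0.3, 0.0, 1.5] ending in mate:
            #   G_2 = 1.5
            #   G_1 = 0.0 + 0.98 * 1.5  = 1.47
            #   G_0 = 0.3 + 0.98 * 1.47 ≈ 1.74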

            # ── PHASE 3: GRPO GROUP-WISE ADVANTAGE NORMALIZATION ───────────────
            # advantages shape [max_steps, B]
            adv_raw = returns_buf - values_buf
            active_f = active_buf.float()

            # Reshape to [max_steps, num_groups, G] and normalize within each group
            adv_3d = adv_raw.view(max_steps, num_groups, G)
            act_3d = active_f.view(max_steps, num_groups, G)

            g_count = act_3d.sum(dim=[0, 2]).clamp(min=1.0)        # [num_groups]
            g_mean = (adv_3d * act_3d).sum(dim=[0, 2]) / g_count   # [num_groups]
            g_sq_diff = ((adv_3d - g_mean.view(1, num_groups, 1)) ** 2
                         * act_3d).sum(dim=[0, 2])
            g_std = (g_sq_diff / g_count).sqrt().clamp(min=1e-8)   # [num_groups]
            adv_3d = (adv_3d - g_mean.view(1, num_groups, 1)) / \
                     g_std.view(1, num_groups, 1)
            adv_norm = adv_3d.view(max_steps, B)
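
            # Group-relative standardization in the GRPO spirit: each advantage
            # is rescaled against its own group's rollout statistics,
            #   A_hat = (A - mean_g) / std_g,
            # so the G envs sharing an opening are compared with each other
            # rather than against one global baseline. (This hybrid still keeps
            # the value head as a per-state baseline inside A = G_t - V(s_t).)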

            # Flatten, filter to active steps only
            valid_mask = active_buf.view(-1)
            flat_states = (states_buf.view(-1, 14, 8, 8)[valid_mask]
                           .to(torch.float32, memory_format=torch.channels_last)
                           .div_(127.0))
            flat_actions = actions_buf.view(-1)[valid_mask].to(torch.int64)
            flat_old_lp = logprobs_buf.view(-1)[valid_mask]
            flat_returns = returns_buf.view(-1)[valid_mask]
            flat_advantages = adv_norm.view(-1)[valid_mask]

            dataset_size = flat_states.size(0)
            if dataset_size < 100:
                continue  # skip degenerate rollout (all games ended instantly)

            # ── PHASE 4: PPO OPTIMIZATION ──────────────────────────────────────
            self.model.train()
            total_p_loss = total_v_loss = 0.0
            num_updates = 0
            mb_size = CONFIG["mini_batch_size"]

            for _ in range(CONFIG["ppo_epochs"]):
                perm = torch.randperm(dataset_size, device=self.device)
                for start in range(0, dataset_size, mb_size):
                    mb = perm[start: start + mb_size]
                    with torch.amp.autocast('cuda'):
                        new_logits, new_vals = self.model(flat_states[mb])
                        new_dist = torch.distributions.Categorical(logits=new_logits)
                        new_lp = new_dist.log_prob(flat_actions[mb])
                        ratio = torch.exp(new_lp - flat_old_lp[mb])
                        adv = flat_advantages[mb]
                        surr1 = ratio * adv
                        surr2 = torch.clamp(
                            ratio,
                            1.0 - CONFIG["clip_epsilon"],
                            1.0 + CONFIG["clip_epsilon"],
                        ) * adv
                        p_loss = -torch.min(surr1, surr2).mean()
                        v_loss = F.mse_loss(new_vals.squeeze(-1), flat_returns[mb])
                        entropy = new_dist.entropy().mean()
                        loss = (p_loss
                                + CONFIG["value_coef"] * v_loss
                                - CONFIG["entropy_coef"] * entropy)

                    self.optimizer.zero_grad(set_to_none=True)
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    total_p_loss += p_loss.item()
                    total_v_loss += v_loss.item()
                    num_updates += 1
621
+ # ── PHASE 5: METRICS & LOGGING ────────────────────────────────
622
+ done_count = white_wins + black_wins + draws_count
623
+ win_rate = white_wins / max(done_count, 1)
624
+ draw_rate = draws_count / max(done_count, 1)
625
+ active_steps = int(active_buf.sum().item())
626
+ check_rate = total_checks / max(active_steps, 1)
627
+ capture_rate = total_captures / max(active_steps, 1)
628
+ avg_game_len = float(game_lengths.mean())
629
+ fps = dataset_size / max(time.time() - iter_start, 1e-3)
630
+
631
+ if (iteration + 1) % CONFIG["log_interval"] == 0:
632
+ vram_alloc = (torch.cuda.memory_allocated() / 1e9
633
+ if torch.cuda.is_available() else 0.0)
634
+ vram_res = (torch.cuda.memory_reserved() / 1e9
635
+ if torch.cuda.is_available() else 0.0)
636
+ print(
637
+ f"[{iteration+1:05d}] "
638
+ f"P:{total_p_loss/max(1,num_updates):.4f} "
639
+ f"V:{total_v_loss/max(1,num_updates):.4f} | "
640
+ f"W:{win_rate:.3f} D:{draw_rate:.3f} "
641
+ f"Chk:{check_rate:.4f} Cap:{capture_rate:.4f} "
642
+ f"Len:{avg_game_len:.1f} | "
643
+ f"ELO:{self.elo_tracker.elo:.0f} | "
644
+ f"FPS:{fps:.0f} | "
645
+ f"VRAM:{vram_alloc:.2f}/{vram_res:.2f}GB"
646
+ )
647
+ with open(self.log_file, "a", newline="") as f:
648
+ csv.writer(f).writerow([
649
+ iteration + 1,
650
+ total_p_loss / max(1, num_updates),
651
+ total_v_loss / max(1, num_updates),
652
+ flat_returns.mean().item(),
653
+ fps, win_rate, draw_rate,
654
+ check_rate, capture_rate, avg_game_len,
655
+ ])
656
+
657
+ # Save best checkpoint when win_rate improves
658
+ if win_rate > self.best_win_rate:
659
+ self.best_win_rate = win_rate
660
+ self.save_checkpoint(iteration + 1, is_best=True)
661
+
662
+ if (iteration + 1) % CONFIG["save_interval"] == 0:
663
+ self.save_checkpoint(iteration + 1)
664
+ self.plot_metrics()
665
+
666
+ # ELO evaluation
667
+ if (iteration + 1) % CONFIG["elo_eval_interval"] == 0:
668
+ elo_before = self.elo_tracker.elo
669
+ ew, ed, el = self.evaluate_elo(CONFIG["elo_eval_games"])
670
+ print(
671
+ f" [ELO eval] {elo_before:.0f} β†’ {self.elo_tracker.elo:.0f} | "
672
+ f"W:{ew} D:{ed} L:{el} vs random({RANDOM_BASELINE_ELO})"
673
+ )
674
+ with open(self.elo_log_file, "a", newline="") as f:
675
+ csv.writer(f).writerow(
676
+ [iteration + 1, self.elo_tracker.elo, ew, ed, el])
677
+ self.plot_metrics()
678
+
679
+ # Aggressive cache reclaim (free fragmented blocks, not pinned allocs)
680
+ torch.cuda.empty_cache()
681
+
682
+ # ── Plotting ───────────────────────────────────────────────────────────────
683
+ def plot_metrics(self) -> None:
684
+ if not os.path.exists(self.log_file): return
685
+ df = pd.read_csv(self.log_file)
686
+ if len(df) < 2: return
687
+
688
+ elo_df = None
689
+ if os.path.exists(self.elo_log_file):
690
+ elo_df = pd.read_csv(self.elo_log_file)
691
+
692
+ fig, axs = plt.subplots(3, 2, figsize=(14, 12))
693
+ fig.suptitle("Aggressive GRPO Chess Agent β€” Training Dashboard", fontsize=14)
694
+
695
+ # Row 0: Losses
696
+ axs[0, 0].plot(df['iteration'], df['p_loss'], color='steelblue', linewidth=1.2)
697
+ axs[0, 0].set_title('Policy Loss'); axs[0, 0].set_xlabel('Iteration')
698
+
699
+ axs[0, 1].plot(df['iteration'], df['v_loss'], color='tomato', linewidth=1.2)
700
+ axs[0, 1].set_title('Value Loss'); axs[0, 1].set_xlabel('Iteration')
701
+
702
+ # Row 1: Outcomes
703
+ axs[1, 0].plot(df['iteration'], df['win_rate'], label='Win', color='green')
704
+ axs[1, 0].plot(df['iteration'], df['draw_rate'], label='Draw', color='orange')
705
+ axs[1, 0].set_title('Outcomes (White perspective)')
706
+ axs[1, 0].legend(); axs[1, 0].set_xlabel('Iteration')
707
+
708
+ # Row 1: Attack metrics
709
+ axs[1, 1].plot(df['iteration'], df['check_rate'], label='Check/step', color='purple')
710
+ axs[1, 1].plot(df['iteration'], df['capture_rate'], label='Capture/step', color='darkorange')
711
+ axs[1, 1].set_title('Attack Metrics (↑ = more aggressive)')
712
+ axs[1, 1].legend(); axs[1, 1].set_xlabel('Iteration')
713
+
714
+ # Row 2: ELO Rating
715
+ if elo_df is not None and len(elo_df) > 0:
716
+ axs[2, 0].plot(elo_df['iteration'], elo_df['elo'],
717
+ color='gold', linewidth=2.0, label='Agent ELO')
718
+ axs[2, 0].axhline(RANDOM_BASELINE_ELO, linestyle='--',
719
+ color='gray', alpha=0.8, label=f'Random ({RANDOM_BASELINE_ELO})')
720
+ axs[2, 0].axhline(1200, linestyle=':', color='lightblue',
721
+ alpha=0.6, label='Start (1200)')
722
+ axs[2, 0].fill_between(elo_df['iteration'], RANDOM_BASELINE_ELO,
723
+ elo_df['elo'], alpha=0.15, color='gold')
724
+ axs[2, 0].set_title('ELO Rating vs Random Baseline')
725
+ axs[2, 0].legend(); axs[2, 0].set_xlabel('Iteration')
726
+ else:
727
+ axs[2, 0].text(0.5, 0.5, f'ELO eval every {CONFIG["elo_eval_interval"]} iters',
728
+ ha='center', va='center', transform=axs[2, 0].transAxes,
729
+ color='gray', fontsize=11)
730
+ axs[2, 0].set_title('ELO Rating (pending)')
731
+
732
+ # Row 2: Average game length
733
+ axs[2, 1].plot(df['iteration'], df['avg_game_len'], color='teal', linewidth=1.2)
734
+ axs[2, 1].set_title('Avg Game Length (↓ = faster checkmates)')
735
+ axs[2, 1].set_xlabel('Iteration')
736
+
737
+ for ax in axs.flat:
738
+ ax.grid(True, alpha=0.25)
739
+
740
+ plt.tight_layout()
741
+ out = os.path.join(CONFIG["checkpoint_dir"], "training_performance.png")
742
+ plt.savefig(out, dpi=100, bbox_inches='tight')
743
+ plt.close(fig)
744
+ print(f" [Plot] saved β†’ {out}")
745
+
746
+

# ── Entry Point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aggressive GRPO Chess Agent (T4/Colab)")
    parser.add_argument("--iterations", type=int, default=10000,
                        help="Total training iterations")
    parser.add_argument("--test-batch", action="store_true",
                        help="Run 2 iterations as a smoke test")
    args, _ = parser.parse_known_args()

    torch.manual_seed(CONFIG["seed"])
    np.random.seed(CONFIG["seed"])
    random.seed(CONFIG["seed"])

    # Print a VRAM summary at startup
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        print(f"GPU: {props.name} | VRAM: {props.total_memory/1e9:.1f}GB | "
              f"SM: {props.multi_processor_count} | "
              f"Compute: {props.major}.{props.minor}")

    trainer = GRPOTrainer()
    trainer.train(2 if args.test_batch else args.iterations)
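
# Usage (illustrative):
#   python model_aggressive.py --test-batch        # 2-iteration smoke test
#   python model_aggressive.py --iterations 5000   # longer run
# parse_known_args() is used instead of parse_args() so stray notebook argv
# entries (e.g. the Jupyter kernel's "-f <connection-file>") don't crash argparse.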