algorembrant commited on
Commit
1af5cb8
·
verified ·
1 Parent(s): 1cb5066

Upload test.py

Browse files
Files changed (1) hide show
  1. test.py +480 -0
test.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # !pip -q install pygame chess torch numpy
2
+
3
+ """
4
+ Interactive Chess GUI with GRPO Model Predictions
5
+ - Movable pieces (drag & drop)
6
+ - Arrow showing top-3 predicted moves
7
+ - Legal move enforcement
8
+ STANDALONE VERSION: Contains necessary model classes to run without model.py.
9
+ """
10
+
11
+ import sys
12
+ import os
13
+ import threading
14
+ import queue
15
+ import math
16
+ from typing import List, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import numpy as np
22
+
23
+ os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
24
+ try:
25
+ import pygame
26
+ import chess
27
+ except ImportError:
28
+ os.system("pip install -q pygame chess")
29
+ import pygame
30
+ import chess
31
+
32
+ # ----------------------------------------------------------------------
33
+ # Core High-Performance Flags
34
+ # ----------------------------------------------------------------------
35
+ torch.backends.cudnn.benchmark = True
36
+ torch.backends.cuda.matmul.allow_tf32 = True
37
+ torch.backends.cudnn.allow_tf32 = True
38
+ if hasattr(torch, 'set_float32_matmul_precision'):
39
+ torch.set_float32_matmul_precision('high')
40
+
41
+
42
+ # ----------------------------------------------------------------------
43
+ # Model Components (Included for standalone execution)
44
+ # ----------------------------------------------------------------------
45
+ class ActionMapper:
46
+ __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']
47
+ def __init__(self):
48
+ self.move_to_idx = {}
49
+ self.idx_to_move = []
50
+ idx = 0
51
+ for f in range(64):
52
+ for t in range(64):
53
+ if f == t: continue
54
+ uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t]
55
+ self.move_to_idx[uci] = idx
56
+ self.idx_to_move.append(uci)
57
+ idx += 1
58
+ if chess.square_rank(f) in (1, 6) and abs(chess.square_file(f) - chess.square_file(t)) <= 1:
59
+ for promo in "nbrq":
60
+ promo_uci = uci + promo
61
+ self.move_to_idx[promo_uci] = idx
62
+ self.idx_to_move.append(promo_uci)
63
+ idx += 1
64
+ self.num_actions = idx
65
+
66
+ ACTION_MAPPER = ActionMapper()
67
+
68
+ class ChessNet(nn.Module):
69
+ def __init__(self):
70
+ super().__init__()
71
+ self.conv_in = nn.Conv2d(14, 128, kernel_size=3, padding=1, bias=False)
72
+ self.bn_in = nn.BatchNorm2d(128)
73
+
74
+ self.res_blocks = nn.ModuleList([
75
+ nn.Sequential(
76
+ nn.Conv2d(128, 128, 3, padding=1, bias=False),
77
+ nn.BatchNorm2d(128),
78
+ nn.ReLU(inplace=True),
79
+ nn.Conv2d(128, 128, 3, padding=1, bias=False),
80
+ nn.BatchNorm2d(128)
81
+ ) for _ in range(6)
82
+ ])
83
+
84
+ self.policy_head = nn.Sequential(
85
+ nn.Conv2d(128, 32, 1, bias=False),
86
+ nn.BatchNorm2d(32),
87
+ nn.ReLU(inplace=True),
88
+ nn.Flatten(),
89
+ nn.Linear(32 * 8 * 8, ACTION_MAPPER.num_actions)
90
+ )
91
+
92
+ self.value_head = nn.Sequential(
93
+ nn.Conv2d(128, 32, 1, bias=False),
94
+ nn.BatchNorm2d(32),
95
+ nn.ReLU(inplace=True),
96
+ nn.Flatten(),
97
+ nn.Linear(32 * 8 * 8, 256),
98
+ nn.ReLU(inplace=True),
99
+ nn.Linear(256, 1),
100
+ nn.Tanh()
101
+ )
102
+
103
+ def forward(self, x):
104
+ x = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
105
+ for block in self.res_blocks:
106
+ x = F.relu(x + block(x), inplace=True)
107
+ return self.policy_head(x), self.value_head(x)
108
+
109
+ def boards_to_tensor_vectorized(envs: List[chess.Board], out_tensor: np.ndarray):
110
+ B = len(envs)
111
+ bbs = np.zeros((B, 12), dtype=np.uint64)
112
+ meta = np.zeros((B, 3), dtype=np.float32)
113
+
114
+ for b, env in enumerate(envs):
115
+ w = env.occupied_co[chess.WHITE]
116
+ bc = env.occupied_co[chess.BLACK]
117
+
118
+ bbs[b, 0] = env.pawns & w
119
+ bbs[b, 1] = env.knights & w
120
+ bbs[b, 2] = env.bishops & w
121
+ bbs[b, 3] = env.rooks & w
122
+ bbs[b, 4] = env.queens & w
123
+ bbs[b, 5] = env.kings & w
124
+ bbs[b, 6] = env.pawns & bc
125
+ bbs[b, 7] = env.knights & bc
126
+ bbs[b, 8] = env.bishops & bc
127
+ bbs[b, 9] = env.rooks & bc
128
+ bbs[b, 10] = env.queens & bc
129
+ bbs[b, 11] = env.kings & bc
130
+
131
+ meta[b, 0] = 1.0 if env.turn else -1.0
132
+ meta[b, 1] = env.castling_rights * 0.1333333 - 1.0
133
+ meta[b, 2] = 1.0 if env.ep_square else -1.0
134
+
135
+ # Bit unpacking (equivalent to model.py torch logic)
136
+ bbs_bytes = bbs.view(np.uint8).reshape(B, 12, 8)
137
+ unpacked = np.unpackbits(bbs_bytes, axis=2, bitorder='little').reshape(B, 12, 8, 8)
138
+
139
+ # Quantize meta to int8 to match model.py training buffer behavior
140
+ meta_int8 = meta.astype(np.int8)
141
+
142
+ out_tensor[:, :12, :, :] = unpacked.astype(np.float32)
143
+
144
+ # Channel 12: Turn (filled 8x8)
145
+ out_tensor[:, 12, :, :] = meta_int8[:, 0].reshape(B, 1, 1)
146
+
147
+ # Channel 13: Castling (filled 8x8) then EP at (0, 1)
148
+ out_tensor[:, 13, :, :] = meta_int8[:, 1].reshape(B, 1, 1)
149
+ for b in range(B):
150
+ out_tensor[b, 13, 0, 1] = meta_int8[b, 2]
151
+
152
+
153
+ # ----------------------------------------------------------------------
154
+ # Pygame Initialization
155
+ # ----------------------------------------------------------------------
156
+ # Suppress video driver errors if running in headless Colab environment
157
+ if "google.colab" in sys.modules:
158
+ print("WARNING: You are running test.py in Google Colab.")
159
+ print("Pygame requires a GUI display which Colab does not have natively.")
160
+ print("It is recommended to run test.py locally on your Windows PC and load the latest.pt file.")
161
+
162
+ try:
163
+ pygame.init()
164
+ screen_test = pygame.display.set_mode((1, 1))
165
+ pygame.display.quit()
166
+ HAS_DISPLAY = True
167
+ except pygame.error:
168
+ print("ERROR: No display detected. Pygame GUI cannot start without a screen.")
169
+ HAS_DISPLAY = False
170
+
171
+ WIDTH, HEIGHT = 800, 800
172
+ SQUARE_SIZE = WIDTH // 8
173
+ FPS = 60
174
+
175
+ LIGHT = (240, 217, 181)
176
+ DARK = (181, 136, 99)
177
+ HIGHLIGHT = (255, 255, 0, 100)
178
+ ARROW_COLOR = (50, 150, 250)
179
+ TEXT_COLOR = (0, 0, 0)
180
+
181
+ def create_piece_surface(piece: chess.Piece) -> pygame.Surface:
182
+ surf = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
183
+
184
+ # Try Windows-specific Unicode fonts first, fallback to default
185
+ font_names = ["segoeuisymbol", "arial", "msgothic"]
186
+ font = None
187
+ is_default = False
188
+ for fn in font_names:
189
+ if fn in pygame.font.get_fonts():
190
+ font = pygame.font.SysFont(fn, int(SQUARE_SIZE * 0.7))
191
+ break
192
+
193
+ if font is None:
194
+ font = pygame.font.Font(None, int(SQUARE_SIZE * 0.6))
195
+ is_default = True
196
+
197
+ symbols = {
198
+ 'P': '♙', 'N': '♘', 'B': '♗', 'R': '♖', 'Q': '♕', 'K': '♔',
199
+ 'p': '♟', 'n': '♞', 'b': '♝', 'r': '♜', 'q': '♛', 'k': '♚'
200
+ }
201
+
202
+ char = symbols[piece.symbol()]
203
+
204
+ # If using default font (freesansbold), unicode chess pieces often render as missing boxes.
205
+ # We fallback to standard English letters to guarantee visibility.
206
+ if is_default:
207
+ char = piece.symbol().upper()
208
+
209
+ color = (255, 255, 255) if piece.color == chess.WHITE else (30, 30, 30)
210
+ outline_color = (0, 0, 0) if piece.color == chess.WHITE else (255, 255, 255)
211
+
212
+ text = font.render(char, True, color)
213
+ text_rect = text.get_rect(center=(SQUARE_SIZE//2, SQUARE_SIZE//2))
214
+
215
+ # Draw outline for better visibility
216
+ for dx, dy in [(-1,-1), (-1,1), (1,-1), (1,1), (-2,0), (2,0), (0,-2), (0,2)]:
217
+ outline = font.render(char, True, outline_color)
218
+ surf.blit(outline, text_rect.move(dx, dy))
219
+
220
+ surf.blit(text, text_rect)
221
+ return surf
222
+
223
+ # ----------------------------------------------------------------------
224
+ # Model Inference (Async Worker)
225
+ # ----------------------------------------------------------------------
226
+ class ModelInference:
227
+ def __init__(self, checkpoint_path: str):
228
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
229
+ self.model = ChessNet().to(self.device).to(memory_format=torch.channels_last)
230
+
231
+ if os.path.exists(checkpoint_path):
232
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
233
+ state_dict = checkpoint['model_state_dict']
234
+ # Handle compiled model prefix
235
+ new_state_dict = {}
236
+ for k, v in state_dict.items():
237
+ name = k.replace('_orig_mod.', '')
238
+ new_state_dict[name] = v
239
+ self.model.load_state_dict(new_state_dict)
240
+ print(f"Loaded checkpoint from {checkpoint_path}")
241
+ else:
242
+ print(f"Warning: Checkpoint '{checkpoint_path}' not found. Using untrained weights.")
243
+
244
+ self.model.eval()
245
+ self.queue = queue.Queue()
246
+ self.running = True
247
+ self.thread = threading.Thread(target=self._worker, daemon=True)
248
+ self.thread.start()
249
+
250
+ def _worker(self):
251
+ state_np = np.zeros((1, 14, 8, 8), dtype=np.float32)
252
+ while self.running:
253
+ try:
254
+ board_fen, callback = self.queue.get(timeout=0.1)
255
+ except queue.Empty:
256
+ continue
257
+
258
+ board = chess.Board(board_fen)
259
+ boards_to_tensor_vectorized([board], state_np)
260
+ tensor = torch.tensor(state_np, dtype=torch.float32, device=self.device).to(memory_format=torch.channels_last)
261
+
262
+ with torch.no_grad():
263
+ if hasattr(torch.amp, 'autocast'):
264
+ with torch.amp.autocast(self.device):
265
+ logits, value = self.model(tensor)
266
+ else:
267
+ with torch.cuda.amp.autocast():
268
+ logits, value = self.model(tensor)
269
+
270
+ probs = torch.softmax(logits.to(torch.float32), dim=-1).cpu().numpy().flatten()
271
+
272
+ legal_moves = list(board.legal_moves)
273
+ legal_indices = [ACTION_MAPPER.move_to_idx[m.uci()] for m in legal_moves]
274
+
275
+ probs_filtered = np.zeros_like(probs)
276
+ probs_filtered[legal_indices] = probs[legal_indices]
277
+ s = probs_filtered.sum()
278
+ if s > 0:
279
+ probs_filtered /= s
280
+
281
+ top_indices = np.argsort(probs_filtered)[-3:][::-1]
282
+ top_moves = [(ACTION_MAPPER.idx_to_move[i], probs_filtered[i]) for i in top_indices if probs_filtered[i] > 0]
283
+
284
+ top_moves_obj = []
285
+ for uci, p in top_moves:
286
+ move = chess.Move.from_uci(uci)
287
+ if move in legal_moves:
288
+ top_moves_obj.append((move, p))
289
+
290
+ callback(top_moves_obj, value.item())
291
+
292
+ def predict_async(self, board: chess.Board, callback):
293
+ while not self.queue.empty():
294
+ try:
295
+ self.queue.get_nowait()
296
+ except queue.Empty:
297
+ pass
298
+ self.queue.put((board.fen(), callback))
299
+
300
+ def shutdown(self):
301
+ self.running = False
302
+ self.thread.join()
303
+
304
+ # ----------------------------------------------------------------------
305
+ # Main GUI Application
306
+ # ----------------------------------------------------------------------
307
+ class ChessApp:
308
+ def __init__(self, model_checkpoint: str):
309
+ if not HAS_DISPLAY:
310
+ print("\nExiting because no GUI display is available.")
311
+ sys.exit(1)
312
+
313
+ pygame.init()
314
+ self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
315
+ pygame.display.set_caption("GRPO Chess Agent - Real-Time Predictions")
316
+ self.clock = pygame.time.Clock()
317
+ self.board = chess.Board()
318
+ self.selected_square = None
319
+ self.valid_moves = []
320
+ self.predicted_arrows = []
321
+ self.prediction_value = 0.0
322
+ self.piece_images = {}
323
+ self._load_pieces()
324
+ self.inference = ModelInference(model_checkpoint)
325
+ self.running = True
326
+ self.update_predictions()
327
+
328
+ def _load_pieces(self):
329
+ for piece_type in chess.PIECE_TYPES:
330
+ for color in (chess.WHITE, chess.BLACK):
331
+ piece = chess.Piece(piece_type, color)
332
+ self.piece_images[(piece_type, color)] = create_piece_surface(piece)
333
+
334
+ def square_to_xy(self, square: chess.Square) -> Tuple[int, int]:
335
+ file_idx = chess.square_file(square)
336
+ rank_idx = 7 - chess.square_rank(square)
337
+ return file_idx * SQUARE_SIZE, rank_idx * SQUARE_SIZE
338
+
339
+ def xy_to_square(self, x: int, y: int) -> Optional[chess.Square]:
340
+ file_idx = x // SQUARE_SIZE
341
+ rank_idx = 7 - (y // SQUARE_SIZE)
342
+ if 0 <= file_idx < 8 and 0 <= rank_idx < 8:
343
+ return chess.square(file_idx, rank_idx)
344
+ return None
345
+
346
+ def draw_board(self):
347
+ for row in range(8):
348
+ for col in range(8):
349
+ color = LIGHT if (row + col) % 2 == 0 else DARK
350
+ rect = pygame.Rect(col * SQUARE_SIZE, row * SQUARE_SIZE, SQUARE_SIZE, SQUARE_SIZE)
351
+ pygame.draw.rect(self.screen, color, rect)
352
+
353
+ def draw_pieces(self):
354
+ for square in chess.SQUARES:
355
+ piece = self.board.piece_at(square)
356
+ if piece:
357
+ x, y = self.square_to_xy(square)
358
+ self.screen.blit(self.piece_images[(piece.piece_type, piece.color)], (x, y))
359
+
360
+ def draw_highlights(self):
361
+ if self.selected_square is not None:
362
+ x, y = self.square_to_xy(self.selected_square)
363
+ s = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
364
+ s.fill((255, 255, 0, 100))
365
+ self.screen.blit(s, (x, y))
366
+ for move in self.valid_moves:
367
+ x, y = self.square_to_xy(move.to_square)
368
+ s = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
369
+ s.fill((0, 255, 0, 60))
370
+ self.screen.blit(s, (x, y))
371
+
372
+ def draw_arrows(self):
373
+ for move, prob in self.predicted_arrows:
374
+ start_x, start_y = self.square_to_xy(move.from_square)
375
+ end_x, end_y = self.square_to_xy(move.to_square)
376
+ start = (start_x + SQUARE_SIZE//2, start_y + SQUARE_SIZE//2)
377
+ end = (end_x + SQUARE_SIZE//2, end_y + SQUARE_SIZE//2)
378
+
379
+ alpha = max(100, int(255 * prob))
380
+ color = (ARROW_COLOR[0], ARROW_COLOR[1], ARROW_COLOR[2], alpha)
381
+ width = max(3, int(12 * prob))
382
+
383
+ arrow_surface = pygame.Surface((WIDTH, HEIGHT), pygame.SRCALPHA)
384
+ pygame.draw.line(arrow_surface, color, start, end, width)
385
+
386
+ angle = math.atan2(end[1]-start[1], end[0]-start[0])
387
+ arrow_len = max(15, int(25 * prob))
388
+ arrow_angle = math.pi/6
389
+ x2 = end[0] - arrow_len * math.cos(angle - arrow_angle)
390
+ y2 = end[1] - arrow_len * math.sin(angle - arrow_angle)
391
+ x3 = end[0] - arrow_len * math.cos(angle + arrow_angle)
392
+ y3 = end[1] - arrow_len * math.sin(angle + arrow_angle)
393
+ pygame.draw.polygon(arrow_surface, color, [end, (x2, y2), (x3, y3)])
394
+
395
+ self.screen.blit(arrow_surface, (0, 0))
396
+
397
+ def draw_info(self):
398
+ font = pygame.font.Font(None, 36)
399
+ text = f"Eval Value: {self.prediction_value:.3f} (W/B)"
400
+ surf = font.render(text, True, TEXT_COLOR)
401
+
402
+ bg_rect = surf.get_rect(topleft=(10, HEIGHT - 40))
403
+ bg_rect.inflate_ip(10, 10)
404
+ pygame.draw.rect(self.screen, (255, 255, 255), bg_rect)
405
+ pygame.draw.rect(self.screen, (0, 0, 0), bg_rect, 2)
406
+
407
+ self.screen.blit(surf, (15, HEIGHT - 35))
408
+
409
+ def update_predictions(self):
410
+ def callback(top_moves, value):
411
+ self.predicted_arrows = top_moves
412
+ self.prediction_value = value
413
+ self.inference.predict_async(self.board, callback)
414
+
415
+ def handle_click(self, pos):
416
+ square = self.xy_to_square(*pos)
417
+ if square is None:
418
+ return
419
+
420
+ if self.selected_square is None:
421
+ piece = self.board.piece_at(square)
422
+ if piece and piece.color == self.board.turn:
423
+ self.selected_square = square
424
+ self.valid_moves = [m for m in self.board.legal_moves if m.from_square == square]
425
+ else:
426
+ move = chess.Move(self.selected_square, square)
427
+ if chess.square_rank(square) in (0, 7) and self.board.piece_at(self.selected_square).piece_type == chess.PAWN:
428
+ move = chess.Move(self.selected_square, square, promotion=chess.QUEEN)
429
+
430
+ if move in self.board.legal_moves:
431
+ self.board.push(move)
432
+ self.selected_square = None
433
+ self.valid_moves = []
434
+ self.update_predictions()
435
+ else:
436
+ piece = self.board.piece_at(square)
437
+ if piece and piece.color == self.board.turn:
438
+ self.selected_square = square
439
+ self.valid_moves = [m for m in self.board.legal_moves if m.from_square == square]
440
+ else:
441
+ self.selected_square = None
442
+ self.valid_moves = []
443
+
444
+ def run(self):
445
+ while self.running:
446
+ for event in pygame.event.get():
447
+ if event.type == pygame.QUIT:
448
+ self.running = False
449
+ elif event.type == pygame.MOUSEBUTTONDOWN:
450
+ if event.button == 1:
451
+ self.handle_click(event.pos)
452
+ elif event.type == pygame.KEYDOWN:
453
+ if event.key == pygame.K_r:
454
+ self.board.reset()
455
+ self.selected_square = None
456
+ self.valid_moves = []
457
+ self.update_predictions()
458
+
459
+ self.screen.fill((0, 0, 0))
460
+ self.draw_board()
461
+ self.draw_highlights()
462
+ self.draw_arrows()
463
+ self.draw_pieces()
464
+ self.draw_info()
465
+
466
+ pygame.display.flip()
467
+ self.clock.tick(FPS)
468
+
469
+ self.inference.shutdown()
470
+ pygame.quit()
471
+ sys.exit()
472
+
473
+ if __name__ == "__main__":
474
+ checkpoint_path = "./checkpoints/latest.pt"
475
+ if len(sys.argv) >= 2 and not sys.argv[1].startswith('-'):
476
+ checkpoint_path = sys.argv[1]
477
+
478
+ app = ChessApp(checkpoint_path)
479
+ if HAS_DISPLAY:
480
+ app.run()