dechantoine commited on
Commit
ba1e298
·
verified ·
1 Parent(s): 1bc5f89

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. src/data/data_utils.py +424 -0
  3. src/models +63 -0
app.py CHANGED
@@ -10,7 +10,7 @@ import numpy as np
10
  from src.data.data_utils import clean_board
11
  from src.engine.agents.policies import beam_search, eval_board, one_depth_eval
12
  from src.engine.agents.viz_utils import plot_save_beam_search, save_svg, board_to_svg
13
- from .models.multi_input_conv import MultiInputConv
14
 
15
  # TEMP_DIR = "./demos/temp/"
16
  CHKPT = "checkpoint.pt"
 
10
  from src.data.data_utils import clean_board
11
  from src.engine.agents.policies import beam_search, eval_board, one_depth_eval
12
  from src.engine.agents.viz_utils import plot_save_beam_search, save_svg, board_to_svg
13
+ from src.models.model_space import MultiInputConv
14
 
15
  # TEMP_DIR = "./demos/temp/"
16
  CHKPT = "checkpoint.pt"
src/data/data_utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+
4
+ import chess.pgn
5
+ import numpy as np
6
+ import torch
7
+ from loguru import logger
8
+
9
+ dict_pieces = {
10
+ "white": {
11
+ "R": "rook",
12
+ "N": "knight",
13
+ "B": "bishop",
14
+ "Q": "queen",
15
+ "K": "king",
16
+ "P": "pawn",
17
+ },
18
+ "black": {
19
+ "r": "rook",
20
+ "n": "knight",
21
+ "b": "bishop",
22
+ "q": "queen",
23
+ "k": "king",
24
+ "p": "pawn",
25
+ },
26
+ }
27
+
28
+
29
+ def arrays_to_lists(data):
30
+ """Recursively transform all numpy arrays in a nested structure into lists.
31
+
32
+ Args:
33
+ data: The nested structure containing numpy arrays.
34
+
35
+ Returns:
36
+ The nested structure with all numpy arrays converted to lists.
37
+
38
+ """
39
+ if isinstance(data, np.ndarray):
40
+ data = data.tolist()
41
+ return [arrays_to_lists(item) for item in data]
42
+ elif isinstance(data, list):
43
+ return [arrays_to_lists(item) for item in data]
44
+ else:
45
+ return data
46
+
47
+
48
+ @logger.catch
49
+ def clean_board(board: str) -> chess.Board:
50
+ """Clean the board string and return a chess.Board object.
51
+
52
+ Args:
53
+ board (str): board string
54
+
55
+ Returns:
56
+ chess.Board: chess.Board object
57
+
58
+ """
59
+ board = board.replace("'", "")
60
+ board = board.replace('"', "")
61
+
62
+ try:
63
+ board = chess.Board(fen=board)
64
+ except ValueError:
65
+ try:
66
+ game = chess.pgn.read_game(io.StringIO(board))
67
+ board = game.board()
68
+ for move in game.mainline_moves():
69
+ board.push(move)
70
+ except ValueError:
71
+ raise ValueError("Invalid FEN or PGN board provided.")
72
+
73
+ return board
74
+
75
+
76
+ @logger.catch
77
+ def format_board(board: chess.Board) -> str:
78
+ """Format a board to a compact string.
79
+
80
+ Args:
81
+ board (chess.Board): board to format.
82
+
83
+ Returns:
84
+ str: formatted board.
85
+
86
+ """
87
+ return str(board).replace("\n", "").replace(" ", "")
88
+
89
+
90
+ @logger.catch
91
+ def string_to_array(str_board: str, is_white: bool = True) -> np.array:
92
+ """Convert a string compact board to a numpy array. The array is of shape (6, 8, 8) and is the one-hot encoding of
93
+ the player pieces.
94
+
95
+ Args:
96
+ str_board (str): compact board.
97
+ is_white (bool, optional): True if white pieces, False otherwise. Defaults to True.
98
+
99
+ Returns:
100
+ np.array: numpy array of shape (6, 8, 8).
101
+
102
+ """
103
+ list_board = list(str_board)
104
+ key = "white" if is_white else "black"
105
+ return np.array(
106
+ [
107
+ np.reshape([1 * (p == piece) for p in list_board], newshape=(8, 8))
108
+ for piece in list(dict_pieces[key])
109
+ ]
110
+ )
111
+
112
+
113
+ def board_to_list_index(board: chess.Board) -> list:
114
+ """Convert a chess board to a list of indexes.
115
+
116
+ Args:
117
+ board (chess.Board): board to convert.
118
+
119
+ Returns:
120
+ list: list of indexes.
121
+
122
+ """
123
+ list_board = list(format_board(board))
124
+ idx_white = [np.flatnonzero([1 * (p == piece) for p in list_board]).tolist()
125
+ for piece in list(dict_pieces["white"])]
126
+ idx_black = [np.flatnonzero([1 * (p == piece) for p in list_board]).tolist()
127
+ for piece in list(dict_pieces["black"])]
128
+
129
+ idx_white = [idx if len(idx) > 0 else None for idx in idx_white]
130
+ idx_black = [idx if len(idx) > 0 else None for idx in idx_black]
131
+
132
+ active_color = 1 * (board.turn == chess.WHITE)
133
+
134
+ castling = [board.has_kingside_castling_rights(chess.WHITE) * 1,
135
+ board.has_queenside_castling_rights(chess.WHITE) * 1,
136
+ board.has_kingside_castling_rights(chess.BLACK) * 1,
137
+ board.has_queenside_castling_rights(chess.BLACK) * 1]
138
+
139
+ en_passant = board.ep_square if board.ep_square else -1
140
+
141
+ list_indexes = idx_white + idx_black + [active_color] + [castling] + [en_passant] + [board.halfmove_clock] + [
142
+ board.fullmove_number]
143
+
144
+ return list_indexes
145
+
146
+
147
+ def list_index_to_fen(idxs: list) -> str:
148
+ """Convert a list of indexes to a FEN string.
149
+
150
+ Args:
151
+ idxs (list): list of indexes.
152
+
153
+ Returns:
154
+ str: FEN string.
155
+
156
+ """
157
+ idx_white = idxs[:6]
158
+ idx_black = idxs[6:12]
159
+ active_color, castling, en_passant, halfmove, fullmove = idxs[12:]
160
+ list_board = ["."] * 64
161
+ for i, piece in enumerate(list(dict_pieces["white"])):
162
+ if idx_white[i]:
163
+ for idx in idx_white[i]:
164
+ list_board[idx] = piece
165
+ for i, piece in enumerate(list(dict_pieces["black"])):
166
+ if idx_black[i]:
167
+ for idx in idx_black[i]:
168
+ list_board[idx] = piece
169
+ for k in range(7):
170
+ list_board.insert(8 * (k + 1) + k, "/")
171
+
172
+ active_color = "w" if active_color else "b"
173
+
174
+ str_castling = ["K" if castling[0] else "",
175
+ "Q" if castling[1] else "",
176
+ "k" if castling[2] else "",
177
+ "q" if castling[3] else ""]
178
+ str_castling = "".join(str_castling)
179
+ str_castling = str_castling if str_castling else "-"
180
+
181
+ en_passant = chess.SQUARE_NAMES[en_passant] if en_passant != -1 else "-"
182
+
183
+ fen = ("".join(list_board) + " "
184
+ + active_color + " "
185
+ + str_castling + " "
186
+ + str(en_passant) + " "
187
+ + str(halfmove) + " "
188
+ + str(fullmove))
189
+ fen = re.sub(r'\.+', lambda m: str(len(m.group())), fen)
190
+ return fen
191
+
192
+
193
+ def list_index_to_tensor(idxs: list) -> np.array:
194
+ """Convert a list of indexes to a tensor.
195
+
196
+ Args:
197
+ idxs (list): list of indexes.
198
+
199
+ Returns:
200
+ np.array: tensor.
201
+
202
+ """
203
+ tensor_pieces = np.zeros((12, 8 * 8), dtype=np.int8)
204
+ for i, list_idx in enumerate(idxs[:12]):
205
+ if list_idx:
206
+ for idx in list_idx:
207
+ tensor_pieces[i, idx] = 1
208
+ tensor_pieces = tensor_pieces.reshape((12, 8, 8))
209
+
210
+ return tensor_pieces
211
+
212
+
213
+ @logger.catch
214
+ def uci_to_coordinates(move: chess.Move) -> tuple:
215
+ """Convert a move in UCI format to coordinates.
216
+
217
+ Args:
218
+ move (chess.Move): move to convert.
219
+
220
+ Returns:
221
+ tuple: coordinates of the origin square and coordinates of the destination square.
222
+
223
+ """
224
+ return (7 - move.from_square // 8, move.from_square % 8), (
225
+ 7 - move.to_square // 8,
226
+ move.to_square % 8,
227
+ )
228
+
229
+
230
+ @logger.catch
231
+ def moves_to_tensor(moves: list[chess.Move]) -> np.array:
232
+ """Convert a list of moves to a (8*8, 8*8) tensor. For each origin square, the tensor contains a vector of size 8*8
233
+ with 1 at the index of the destination squares in list of moves, 0 otherwise.
234
+
235
+ Args:
236
+ moves (list[chess.Move]): list of moves.
237
+
238
+ Returns:
239
+ np.array: tensor of possible moves from each square.
240
+
241
+ """
242
+ moves_tensor = np.zeros(shape=(8 * 8, 8 * 8), dtype=np.int8)
243
+ for move in moves:
244
+ from_coordinates, to_coordinates = uci_to_coordinates(move)
245
+ moves_tensor[
246
+ from_coordinates[0] * 8 + from_coordinates[1],
247
+ to_coordinates[0] * 8 + to_coordinates[1],
248
+ ] = 1
249
+ return moves_tensor
250
+
251
+
252
+ @logger.catch
253
+ def board_to_tensor(board: chess.Board) -> tuple[np.array, np.array, np.array]:
254
+ """Convert a board to a tuple of shapes ((12, 8, 8), (1) , (4)). The tuple contains the one-hot encoding of the
255
+ board, the active color and the castling rights.
256
+
257
+ Args:
258
+ board (chess.Board): board to convert.
259
+
260
+ Returns:
261
+ tuple[np.array, np.array, np.array]: tuple of tensors.
262
+
263
+ """
264
+ list_board = list(format_board(board))
265
+
266
+ idx_white = [np.flatnonzero([1 * (p == piece) for p in list_board]).tolist()
267
+ for piece in list(dict_pieces["white"])]
268
+ idx_black = [np.flatnonzero([1 * (p == piece) for p in list_board]).tolist()
269
+ for piece in list(dict_pieces["black"])]
270
+
271
+ active_color = 1 * (board.turn == chess.WHITE)
272
+
273
+ castling = [board.has_kingside_castling_rights(chess.WHITE) * 1,
274
+ board.has_queenside_castling_rights(chess.WHITE) * 1,
275
+ board.has_kingside_castling_rights(chess.BLACK) * 1,
276
+ board.has_queenside_castling_rights(chess.BLACK) * 1]
277
+
278
+ return list_index_to_tensor(idx_white + idx_black), np.array([active_color]), np.array(castling)
279
+
280
+
281
+ @logger.catch
282
+ def batch_moves_to_tensor(batch_moves: list[list[chess.Move]]) -> np.array:
283
+ """Convert a batch of list of moves to a batch of (8*8, 8*8) tensors.
284
+
285
+ Args:
286
+ batch_moves (list[list[chess.Move]]): batch of list of moves.
287
+
288
+ Returns:
289
+ list[np.array]: batch of moves tensors.
290
+
291
+ """
292
+
293
+ return np.array([moves_to_tensor(moves) for moves in batch_moves])
294
+
295
+
296
+ @logger.catch
297
+ def batch_boards_to_tensor(
298
+ batch_boards: list[chess.Board]
299
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
300
+ """Convert a batch of boards to a batch of board tensors.
301
+
302
+ Args:
303
+ batch_boards (list[chess.Board]): batch of boards to convert.
304
+
305
+ Returns:
306
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: tuple of tensors.
307
+
308
+ """
309
+ tensors = [board_to_tensor(board) for board in batch_boards]
310
+ return (torch.Tensor(np.array([tensors[i][0] for i in range(len(tensors))])),
311
+ torch.Tensor(np.array([tensors[i][1] for i in range(len(tensors))])),
312
+ torch.Tensor(np.array([tensors[i][2] for i in range(len(tensors))])))
313
+
314
+
315
+ @logger.catch
316
+ def game_to_legal_moves_tensor(game: chess.pgn.Game) -> np.array:
317
+ """Convert a game to a tensor of legal moves. The tensor is of shape (nb_moves, 8*8, 8*8) and contains a tensor of
318
+ legal moves for each move of the game.
319
+
320
+ Args:
321
+ game (chess.pgn.Game): game to convert.
322
+
323
+ Returns:
324
+ np.array: tensor of legal moves.
325
+
326
+ """
327
+ board = game.board()
328
+ boards = []
329
+ for move in game.mainline_moves():
330
+ board.push(move)
331
+ boards.append(board.copy())
332
+ legal_moves_tensors = batch_moves_to_tensor(
333
+ [list(board.legal_moves) for board in boards]
334
+ )
335
+ return np.array(legal_moves_tensors)
336
+
337
+
338
+ @logger.catch
339
+ def game_to_board_tensor(game: chess.pgn.Game) -> np.array:
340
+ """Convert a game to a tensor of boards. The tensor is of shape (nb_moves, 12, 8, 8) and contains a board tensor for
341
+ each move of the game.
342
+
343
+ Args:
344
+ game (chess.pgn.Game): game to convert.
345
+
346
+ Returns:
347
+ np.array: tensor of boards.
348
+
349
+ """
350
+ board = game.board()
351
+ boards = []
352
+ for move in game.mainline_moves():
353
+ board.push(move)
354
+ boards.append(board.copy())
355
+ board_tensors = batch_boards_to_tensor(boards)
356
+ return np.array(board_tensors)
357
+
358
+
359
+ @logger.catch(exclude=ValueError)
360
+ def result_to_tensor(result: str) -> np.array:
361
+ """Convert a game result to a tensor. The tensor is of shape (1,) and contains 1 for a white win, 0 for a draw and
362
+ -1 for a white loss.
363
+
364
+ Args:
365
+ result (str): game result.
366
+
367
+ Returns:
368
+ np.array: tensor of game result.
369
+
370
+ """
371
+ if result == "1-0":
372
+ return np.array([1], dtype=np.int8)
373
+ elif result == "0-1":
374
+ return np.array([-1], dtype=np.int8)
375
+ elif result == "1/2-1/2":
376
+ return np.array([0], dtype=np.int8)
377
+ else:
378
+ raise ValueError(f"Result {result} not supported.")
379
+
380
+
381
+ @logger.catch
382
+ def batch_results_to_tensor(batch_results: list[str]) -> np.array:
383
+ """Convert a batch of game results to a tensor. The tensor is of shape (nb_games, 1) and contains a tensor of game
384
+ result for each game of the batch.
385
+
386
+ Args:
387
+ batch_results (list[str]): batch of game results.
388
+
389
+ Returns:
390
+ np.array: tensor of game results.
391
+
392
+ """
393
+ return np.array([result_to_tensor(result) for result in batch_results])
394
+
395
+
396
+ @logger.catch
397
+ def read_boards_from_pgn(pgn_file: str, start_move: int = 0, end_move: int = 0) -> list[chess.Board]:
398
+ """Read boards from a PGN file.
399
+
400
+ Args:
401
+ pgn_file (str): path to the PGN file
402
+ start_move (int): move to start from in each game
403
+ end_move (int): move to end at in each game (counting from the end)
404
+
405
+ Returns:
406
+ list[chess.Board]: list of boards
407
+
408
+ """
409
+ pgn = open(pgn_file)
410
+ game = chess.pgn.read_game(pgn)
411
+ boards = []
412
+
413
+ while game:
414
+ board = game.board()
415
+ mainline = list(game.mainline_moves())
416
+ end_index = len(mainline) - end_move
417
+
418
+ for i, move in enumerate(mainline[:end_index]):
419
+ board.push(move)
420
+ if start_move <= i:
421
+ boards.append(board.copy())
422
+ game = chess.pgn.read_game(pgn)
423
+
424
+ return boards
src/models ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+
3
+ import torch
4
+ from loguru import logger
5
+
6
+
7
+ class MultiInputConv(torch.nn.Module):
8
+ @logger.catch
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.flatten = torch.nn.Flatten()
12
+ self.conv_long = torch.nn.Sequential(
13
+ torch.nn.Conv2d(in_channels=12, out_channels=16, kernel_size=15, padding=7, stride=2),
14
+ torch.nn.LeakyReLU(),
15
+ torch.nn.Conv2d(in_channels=16, out_channels=4, kernel_size=7, padding=3, stride=2),
16
+ torch.nn.LeakyReLU(),
17
+ )
18
+ self.conv_middle = torch.nn.Sequential(
19
+ torch.nn.Conv2d(in_channels=12, out_channels=16, kernel_size=9, padding=4, stride=2),
20
+ torch.nn.LeakyReLU(),
21
+ torch.nn.Conv2d(in_channels=16, out_channels=4, kernel_size=7, padding=3, stride=2),
22
+ torch.nn.LeakyReLU(),
23
+ )
24
+ self.conv_short = torch.nn.Sequential(
25
+ torch.nn.Conv2d(in_channels=12, out_channels=16, kernel_size=5, padding=2, stride=2),
26
+ torch.nn.LeakyReLU(),
27
+ torch.nn.Conv2d(in_channels=16, out_channels=4, kernel_size=7, padding=3, stride=2),
28
+ torch.nn.LeakyReLU(),
29
+ )
30
+ self.linear_relu_stack = torch.nn.Sequential(
31
+ torch.nn.Linear(in_features=(4 * 2 * 2) + (4 * 2 * 2) + (4 * 2 * 2) + 1 + 4, out_features=16),
32
+ torch.nn.LeakyReLU(),
33
+ torch.nn.Linear(in_features=16, out_features=1),
34
+ )
35
+
36
+
37
+ @logger.catch
38
+ def forward(self, x):
39
+ board, color, castling = x
40
+ board = board.float()
41
+ color = color.float()
42
+ castling = castling.float()
43
+
44
+ long = self.conv_long(board)
45
+ long = self.flatten(long)
46
+
47
+ middle = self.conv_middle(board)
48
+ middle = self.flatten(middle)
49
+
50
+ short = self.conv_short(board)
51
+ short = self.flatten(short)
52
+
53
+ x = torch.cat((long, middle, short, color, castling), dim=1)
54
+
55
+ score = self.linear_relu_stack(x)
56
+ return score
57
+
58
+ @logger.catch
59
+ def model_hash(self) -> str:
60
+ """Get the hash of the model."""
61
+ return hashlib.md5(
62
+ (str(self.linear_relu_stack) + str(self.flatten)).encode()
63
+ ).hexdigest()