Maxlegrec commited on
Commit
42d76c2
·
verified ·
1 Parent(s): ba34655

Update model architecture: d_ff=1024, new weights from merged7.pt

Browse files
Files changed (3) hide show
  1. config.json +9 -6
  2. model.safetensors +2 -2
  3. modeling_chessbot.py +423 -461
config.json CHANGED
@@ -1,15 +1,18 @@
1
  {
2
- "d_ff": 736,
 
 
 
 
 
 
 
3
  "d_model": 512,
4
  "max_position_embeddings": 64,
5
  "model_type": "chessbot",
6
- "architectures": ["ChessBotModel"],
7
- "auto_map": {
8
- "AutoModel": "modeling_chessbot.ChessBotModel",
9
- "AutoConfig": "modeling_chessbot.ChessBotConfig"
10
- },
11
  "num_heads": 8,
12
  "num_layers": 10,
 
13
  "transformers_version": "4.53.1",
14
  "vocab_size": 1929
15
  }
 
1
  {
2
+ "architectures": [
3
+ "ChessBotModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_chessbot.ChessBotConfig",
7
+ "AutoModel": "modeling_chessbot.ChessBotModel"
8
+ },
9
+ "d_ff": 1024,
10
  "d_model": 512,
11
  "max_position_embeddings": 64,
12
  "model_type": "chessbot",
 
 
 
 
 
13
  "num_heads": 8,
14
  "num_layers": 10,
15
+ "torch_dtype": "float32",
16
  "transformers_version": "4.53.1",
17
  "vocab_size": 1929
18
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:68731e6162fc025c204e9e0de091e1309d77a0edd982a19824934ffb8e73c2f3
3
- size 126985128
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:824ed0a0d945ebf519eee41755d7a7d29487bbdc92e49b94aaf97de6105f7b17
3
+ size 138793144
modeling_chessbot.py CHANGED
@@ -1,14 +1,7 @@
1
  """
2
- Standalone ChessBot Chess Model
3
 
4
- This file contains all the necessary code to run the ChessBot model
5
- without requiring the HFChessRL package installation.
6
-
7
- Requirements:
8
- - torch>=2.0.0
9
- - transformers>=4.30.0
10
- - python-chess>=1.10.0
11
- - numpy>=1.21.0
12
  """
13
 
14
  import torch
@@ -20,6 +13,350 @@ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoMode
20
  from transformers.modeling_outputs import BaseModelOutput
21
  from typing import Optional, Tuple
22
  import math
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  # Configuration class
@@ -34,7 +371,7 @@ class ChessBotConfig(PretrainedConfig):
34
  self,
35
  num_layers: int = 10,
36
  d_model: int = 512,
37
- d_ff: int = 736,
38
  num_heads: int = 8,
39
  vocab_size: int = 1929,
40
  max_position_embeddings: int = 64,
@@ -49,339 +386,6 @@ class ChessBotConfig(PretrainedConfig):
49
  self.max_position_embeddings = max_position_embeddings
50
 
51
 
52
- # FEN encoding function
53
- def fen_to_tensor(fen: str):
54
- """
55
- Convert FEN string to tensor representation for the model.
56
- """
57
- board = chess.Board(fen)
58
- tensor = np.zeros((8, 8, 19), dtype=np.float32)
59
-
60
- # Piece mapping
61
- piece_map = {
62
- 'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
63
- 'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
64
- }
65
-
66
- # Fill piece positions
67
- for square in chess.SQUARES:
68
- piece = board.piece_at(square)
69
- if piece:
70
- row = 7 - (square // 8) # Flip vertically for proper orientation
71
- col = square % 8
72
- tensor[row, col, piece_map[piece.symbol()]] = 1.0
73
-
74
- # Add metadata channels
75
- # Channel 12: White to move
76
- if board.turn == chess.WHITE:
77
- tensor[:, :, 12] = 1.0
78
-
79
- # Channel 13: Black to move
80
- if board.turn == chess.BLACK:
81
- tensor[:, :, 13] = 1.0
82
-
83
- # Castling rights
84
- if board.has_kingside_castling_rights(chess.WHITE):
85
- tensor[:, :, 14] = 1.0
86
- if board.has_queenside_castling_rights(chess.WHITE):
87
- tensor[:, :, 15] = 1.0
88
- if board.has_kingside_castling_rights(chess.BLACK):
89
- tensor[:, :, 16] = 1.0
90
- if board.has_queenside_castling_rights(chess.BLACK):
91
- tensor[:, :, 17] = 1.0
92
-
93
- # En passant
94
- if board.ep_square is not None:
95
- ep_row = 7 - (board.ep_square // 8)
96
- ep_col = board.ep_square % 8
97
- tensor[ep_row, ep_col, 18] = 1.0
98
-
99
- return tensor
100
-
101
-
102
- # Complete policy index with all 1929 moves
103
- policy_index = [
104
- "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
105
- "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
106
- "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
107
- "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
108
- "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
109
- "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
110
- "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
111
- "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
112
- "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
113
- "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
114
- "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
115
- "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
116
- "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
117
- "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
118
- "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
119
- "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
120
- "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
121
- "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
122
- "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
123
- "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
124
- "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
125
- "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
126
- "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
127
- "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
128
- "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
129
- "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
130
- "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
131
- "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
132
- "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
133
- "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
134
- "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
135
- "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
136
- "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
137
- "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
138
- "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
139
- "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
140
- "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
141
- "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
142
- "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
143
- "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
144
- "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
145
- "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
146
- "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
147
- "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
148
- "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
149
- "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
150
- "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
151
- "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
152
- "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
153
- "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
154
- "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
155
- "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
156
- "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
157
- "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
158
- "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
159
- "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
160
- "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
161
- "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
162
- "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
163
- "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
164
- "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
165
- "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
166
- "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
167
- "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
168
- "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
169
- "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
170
- "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
171
- "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
172
- "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
173
- "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
174
- "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
175
- "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
176
- "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
177
- "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
178
- "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
179
- "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
180
- "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
181
- "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
182
- "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
183
- "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
184
- "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
185
- "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
186
- "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
187
- "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
188
- "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
189
- "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
190
- "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
191
- "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
192
- "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
193
- "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
194
- "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
195
- "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
196
- "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
197
- "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
198
- "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
199
- "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
200
- "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
201
- "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
202
- "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
203
- "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
204
- "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
205
- "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
206
- "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
207
- "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
208
- "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
209
- "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
210
- "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
211
- "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
212
- "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
213
- "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
214
- "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
215
- "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
216
- "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
217
- "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
218
- "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
219
- "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
220
- "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
221
- "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
222
- "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
223
- "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
224
- "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
225
- "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
226
- "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
227
- "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
228
- "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
229
- "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
230
- "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
231
- "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
232
- "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
233
- "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
234
- "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
235
- "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
236
- "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
237
- "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
238
- "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
239
- "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
240
- "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
241
- "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
242
- "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
243
- "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
244
- "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
245
- "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
246
- "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
247
- "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
248
- "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
249
- "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
250
- "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
251
- "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
252
- "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
253
- "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
254
- "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
255
- "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
256
- "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
257
- "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
258
- "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
259
- "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
260
- "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
261
- "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
262
- "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
263
- "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
264
- "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
265
- "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
266
- "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
267
- "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
268
- "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
269
- "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
270
- "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
271
- "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
272
- "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
273
- "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
274
- "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
275
- "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
276
- "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
277
- "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
278
- "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
279
- "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
280
- "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
281
- "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
282
- "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
283
- "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
284
- "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
285
- "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
286
- "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
287
- "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
288
- "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
289
- "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
290
- "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
291
- "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
292
- "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
293
- "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
294
- "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
295
- "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
296
- "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
297
- "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
298
- "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
299
- "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
300
- "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
301
- "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
302
- "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
303
- "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
304
- "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
305
- "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
306
- "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
307
- "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
308
- "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
309
- "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
310
- "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
311
- "h7h8q", "h7h8r", "h7h8b", #add the promotions for black
312
- "a2a1q","a2a1r","a2a1b","a2b1q","a2b1r","a2b1b",
313
- "b2a1q","b2a1r","b2a1b","b2b1q","b2b1r","b2b1b","b2c1q","b2c1r","b2c1b",
314
- "c2b1q","c2b1r","c2b1b","c2c1q","c2c1r","c2c1b","c2d1q","c2d1r","c2d1b",
315
- "d2c1q","d2c1r","d2c1b","d2d1q","d2d1r","d2d1b","d2e1q","d2e1r","d2e1b",
316
- "e2d1q","e2d1r","e2d1b","e2e1q","e2e1r","e2e1b","e2f1q","e2f1r","e2f1b",
317
- "f2e1q","f2e1r","f2e1b","f2f1q","f2f1r","f2f1b","f2g1q","f2g1r","f2g1b",
318
- "g2f1q","g2f1r","g2f1b","g2g1q","g2g1r","g2g1b","g2h1q","g2h1r","g2h1b",
319
- "h2g1q","h2g1r","h2g1b","h2h1q","h2h1r","h2h1b",#add special tokens
320
- "<thinking>","</thinking>","end_variation","end","padding_token"
321
- ]
322
-
323
-
324
-
325
- # Attention mechanism
326
- class RelativeMultiHeadAttention2(nn.Module):
327
- def __init__(self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1):
328
- super().__init__()
329
- assert d_model % num_heads == 0
330
- self.d_model = d_model
331
- self.num_heads = num_heads
332
- self.d_head = d_model // num_heads
333
- self.sqrt_dim = math.sqrt(d_model)
334
-
335
- self.query_proj = nn.Linear(d_model, d_model)
336
- self.key_proj = nn.Linear(d_model, d_model)
337
- self.value_proj = nn.Linear(d_model, d_model)
338
- self.pos_proj = nn.Linear(d_model, d_model)
339
- self.out_proj = nn.Linear(d_model, d_model)
340
-
341
- self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
342
- self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
343
- torch.nn.init.xavier_uniform_(self.u_bias)
344
- torch.nn.init.xavier_uniform_(self.v_bias)
345
- self.dropout = nn.Dropout(dropout_p)
346
-
347
- def forward(self, query, key, value, pos_embedding, mask=None):
348
- batch_size = value.size(0)
349
-
350
- query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
351
- key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
352
- value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
353
-
354
- pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
355
-
356
- content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
357
- pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
358
- pos_score = self._compute_relative_positional_encoding(pos_score)
359
-
360
- score = (content_score + pos_score) / self.sqrt_dim
361
-
362
- if mask is not None:
363
- mask = mask.unsqueeze(1)
364
- score.masked_fill_(mask, -1e9)
365
-
366
- attn = F.softmax(score, -1)
367
- attn = self.dropout(attn)
368
-
369
- context = torch.matmul(attn, value).transpose(1, 2)
370
- context = context.contiguous().view(batch_size, -1, self.d_model)
371
-
372
- return self.out_proj(context)
373
-
374
- def _compute_relative_positional_encoding(self, pos_score):
375
- batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
376
- zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
377
- padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
378
-
379
- padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
380
- pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
381
-
382
- return pos_score
383
-
384
-
385
  # Model components
386
  class MaGating(nn.Module):
387
  def __init__(self, d_model):
@@ -409,8 +413,8 @@ class EncoderLayer(nn.Module):
409
  x = self.norm1(x)
410
 
411
  y = self.ff1(x)
412
- y = self.ff2(y)
413
  y = self.gelu(y)
 
414
  y = y + x
415
  y = self.norm2(y)
416
 
@@ -495,19 +499,19 @@ class ChessBotPreTrainedModel(PreTrainedModel):
495
 
496
  class ChessBotModel(ChessBotPreTrainedModel):
497
  """
498
- HuggingFace compatible ChessBot Chess model with ALL original functionality
499
  """
500
 
501
  def __init__(self, config):
502
  super().__init__(config)
503
  self.config = config
504
 
505
- # Initialize exactly like the original BT4 model
506
  self.is_thinking_model = False
507
  self.d_model = config.d_model
508
  self.num_layers = config.num_layers
509
 
510
- # Model layers - same as original
511
  self.layers = nn.ModuleList([
512
  EncoderLayer(config.d_model, config.d_ff, config.num_heads)
513
  for _ in range(config.num_layers)
@@ -576,12 +580,34 @@ class ChessBotModel(ChessBotPreTrainedModel):
576
  targets = inp[1]
577
  true_values = inp[3]
578
  q_values = inp[4]
579
- loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
 
580
  z = torch.argmax(true_values, dim=-1)
581
- loss_value = F.cross_entropy(value_h.view(-1, value_h.size(-1)), z.view(-1), ignore_index=3)
582
- value_h_q = torch.softmax(value_h_q, dim=-1)
583
- loss_q = F.mse_loss(value_h_q.view(-1, value_h_q.size(-1)), q_values.view(-1, 3))
584
- return policy, value_h, loss_policy, loss_value, loss_q, targets, z
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
  return BaseModelOutput(
587
  last_hidden_state=x,
@@ -590,80 +616,43 @@ class ChessBotModel(ChessBotPreTrainedModel):
590
  ), policy, value_h, value_h_q
591
 
592
  def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
593
- """
594
- Get a move from FEN string without thinking
595
- """
596
  board = chess.Board(fen)
597
- legal_moves = [move.uci() if move.uci() in policy_index else move.uci()[:-1] for move in board.legal_moves]
598
- if not legal_moves:
599
- return None
600
-
601
- # Convert FEN to tensor
602
- fen_tensor = fen_to_tensor(fen)
603
- fen_tensor = torch.from_numpy(fen_tensor).float().to(device)
604
- fen_tensor = fen_tensor.unsqueeze(0).unsqueeze(0) # Add batch and sequence dimensions
605
-
606
- # Get model prediction
607
- with torch.no_grad():
608
- _, policy, _, _ = self.forward(fen_tensor)
609
- policy = policy.squeeze(0).squeeze(0) # Remove batch and sequence dimensions
610
 
611
- if T == 0:
612
- if force_legal:
613
- # Find the move with the highest policy value that is legal
614
- legal_moves_mask = - torch.ones_like(policy) * 999
615
- for move in legal_moves:
616
- legal_moves_mask[policy_index.index(move)] = 0
617
- policy = legal_moves_mask + policy
618
- return policy_index[torch.argmax(policy).item()]
619
  else:
620
- max_policy_index = torch.argmax(policy).item()
621
- max_policy_move = policy_index[max_policy_index]
622
- return max_policy_move
623
-
624
- # Apply temperature
625
- if T > 0:
626
- policy = policy / T
627
-
628
- # Convert to probabilities
629
- probs = F.softmax(policy, dim=-1)
630
-
631
- # Map to legal moves
632
- legal_move_probs = {}
633
- for move in legal_moves:
634
- idx = policy_index.index(move)
635
- legal_move_probs[move] = probs[idx].item()
636
-
637
- # Select move based on probabilities
638
- if return_probs:
639
- return legal_move_probs
640
 
 
641
  if force_legal:
642
- # Only consider legal moves
643
- moves = list(legal_move_probs.keys())
644
- move_probs = list(legal_move_probs.values())
645
 
646
- # Normalize probabilities
647
- total_prob = sum(move_probs)
648
- move_probs = [p / total_prob for p in move_probs]
649
- selected_move = np.random.choice(moves, p=move_probs)
650
  else:
651
- # Consider all moves in policy
652
- selected_move = policy_index[torch.multinomial(probs, 1).item()]
653
-
654
- return selected_move
655
 
 
 
 
656
  def get_position_value(self, fen, device="cuda"):
657
- """
658
- Get the value evaluation for a given FEN position.
659
- Returns the value vector [black_win_prob, draw_prob, white_win_prob]
660
- """
661
  x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
662
  x = x.view(1, 1, 8, 8, 19)
663
 
664
- # Forward pass through the model to get value
665
  with torch.no_grad():
666
- # We need to run through the model layers to get to value_head
667
  b, seq_len, _, _, emb = x.size()
668
  x_processed = x.view(b * seq_len, 64, emb)
669
  x_processed = self.linear1(x_processed)
@@ -677,34 +666,22 @@ class ChessBotModel(ChessBotPreTrainedModel):
677
 
678
  value_logits = self.value_head_q(x_processed)
679
  value_logits = value_logits.view(b, seq_len, 3)
680
- value_logits = torch.softmax(value_logits, dim=-1)
681
-
682
- return value_logits.squeeze() # Remove batch and sequence dimensions
683
 
684
  def get_batch_position_values(self, fens, device="cuda"):
685
- """
686
- Get the value evaluation for a batch of FEN positions efficiently.
687
- Args:
688
- fens: List of FEN strings
689
- device: Device to run computations on
690
- Returns:
691
- value_probs: Tensor of shape [batch_size, 3] with [black_win_prob, draw_prob, white_win_prob] for each position
692
- """
693
  if len(fens) == 0:
694
  return torch.empty(0, 3, device=device)
695
 
696
- # Convert all FENs to tensors and stack them
697
  position_tensors = []
698
  for fen in fens:
699
  x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
700
  position_tensors.append(x)
701
 
702
- # Stack to create batch: [batch_size, 8, 8, 19]
703
  batch_x = torch.stack(position_tensors, dim=0)
704
- # Reshape to [batch_size, 1, 8, 8, 19] for the model
705
  batch_x = batch_x.unsqueeze(1)
706
 
707
- # Forward pass through the model to get values
708
  with torch.no_grad():
709
  b, seq_len, _, _, emb = batch_x.size()
710
  x_processed = batch_x.view(b * seq_len, 64, emb)
@@ -720,90 +697,75 @@ class ChessBotModel(ChessBotPreTrainedModel):
720
  value_logits = self.value_head_q(x_processed)
721
  value_logits = value_logits.view(b, seq_len, 3)
722
  value_logits = torch.softmax(value_logits, dim=-1)
723
- return value_logits.squeeze(1) # Remove sequence dimension, keep batch dimension
724
 
725
  def calculate_move_values(self, fen, device="cuda"):
726
- """
727
- Calculate the value for each legal move from the given position efficiently using batching.
728
- For white to move, value = white_win_prob - black_win_prob
729
- For black to move, value = black_win_prob - white_win_prob
730
- """
731
  board = chess.Board()
732
  board.set_fen(fen)
733
 
734
- # Determine whose turn it is
735
  is_white_turn = board.turn == chess.WHITE
736
 
737
  legal_moves = list(board.legal_moves)
738
  if len(legal_moves) == 0:
739
  return [], torch.empty(0, device=device)
740
 
741
- # Get all resulting FENs after each move
742
  resulting_fens = []
743
  for move in legal_moves:
744
  board.push(move)
745
  resulting_fens.append(board.fen())
746
  board.pop()
747
 
748
- # Batch process all positions in a single inference
749
  batch_value_q = self.get_batch_position_values(resulting_fens, device)
750
 
751
  # Calculate values from the current player's perspective
752
- # batch_value_probs[:, 0] = black_win_prob, [:, 1] = draw_prob, [:, 2] = white_win_prob
753
  batch_value_q = batch_value_q[:,2]-batch_value_q[:,0]
754
  if is_white_turn:
755
- # White's perspective: white_win_prob - black_win_prob
756
  player_values = batch_value_q
757
  else:
758
- # Black's perspective: black_win_prob - white_win_prob
759
  player_values = -batch_value_q
760
 
761
  return legal_moves, player_values
762
 
763
- def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False):
764
- """
765
- Determine the best move based on the value of resulting positions using efficient batching.
766
-
767
- Args:
768
- fen: FEN string of the position (works for both white and black to move)
769
- T: Temperature for sampling (T=0 for greedy, T>0 for stochastic)
770
- device: Device to run computations on
771
- return_probs: Whether to return the probability distribution
772
 
773
- Returns:
774
- move: UCI string of the selected move
775
- probs (optional): probability distribution over moves if return_probs=True
776
- """
 
 
 
 
 
777
  legal_moves, move_values = self.calculate_move_values(fen, device)
778
 
779
  if len(legal_moves) == 0:
780
  raise ValueError("No legal moves available")
781
 
782
  if T == 0:
783
- # Greedy selection - choose move with highest value
784
  best_idx = torch.argmax(move_values)
785
  selected_move = legal_moves[best_idx]
786
  else:
787
- # Stochastic selection based on move values
788
- # Convert values to probabilities using softmax with temperature
789
  probs = F.softmax(move_values / T, dim=0)
790
-
791
- # Sample according to probabilities
792
  sampled_idx = torch.multinomial(probs, num_samples=1)
793
  selected_move = legal_moves[sampled_idx.item()]
794
 
795
- # Convert chess.Move to UCI string
796
  move_uci = selected_move.uci()
797
 
798
  if return_probs:
799
  if T == 0:
800
- # Create one-hot distribution for greedy case
801
  probs = torch.zeros_like(move_values)
802
  probs[best_idx] = 1.0
803
  else:
804
  probs = F.softmax(move_values / T, dim=0)
805
 
806
- # Create dictionary with move strings as keys
807
  move_dict = {}
808
  for i, move in enumerate(legal_moves):
809
  move_dict[move.uci()] = probs[i].item()
 
1
  """
2
+ Updated HuggingFace Compatible ChessBot Chess Model
3
 
4
+ This file contains the updated architecture with d_ff=1024 and new weights
 
 
 
 
 
 
 
5
  """
6
 
7
  import torch
 
13
  from transformers.modeling_outputs import BaseModelOutput
14
  from typing import Optional, Tuple
15
  import math
16
+ import sys
17
+ import os
18
+
19
+ # Add current directory to path for imports
20
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
21
+
22
+ # Import attention mechanism
23
+ try:
24
+ from .attn import RelativeMultiHeadAttention2
25
+ except:
26
+ try:
27
+ from attn import RelativeMultiHeadAttention2
28
+ except:
29
+ # Fallback attention implementation
30
+ class RelativeMultiHeadAttention2(nn.Module):
31
+ def __init__(self, d_model: int = 512, num_heads: int = 8, dropout_p: float = 0.1):
32
+ super().__init__()
33
+ assert d_model % num_heads == 0
34
+ self.d_model = d_model
35
+ self.num_heads = num_heads
36
+ self.d_head = d_model // num_heads
37
+ self.sqrt_dim = math.sqrt(d_model)
38
+
39
+ self.query_proj = nn.Linear(d_model, d_model)
40
+ self.key_proj = nn.Linear(d_model, d_model)
41
+ self.value_proj = nn.Linear(d_model, d_model)
42
+ self.pos_proj = nn.Linear(d_model, d_model)
43
+ self.out_proj = nn.Linear(d_model, d_model)
44
+
45
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
46
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
47
+ torch.nn.init.xavier_uniform_(self.u_bias)
48
+ torch.nn.init.xavier_uniform_(self.v_bias)
49
+ self.dropout = nn.Dropout(dropout_p)
50
+
51
+ def forward(self, query, key, value, pos_embedding, mask=None):
52
+ batch_size = value.size(0)
53
+
54
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
55
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
56
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
57
+
58
+ pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
59
+
60
+ content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
61
+ pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
62
+ pos_score = self._compute_relative_positional_encoding(pos_score)
63
+
64
+ score = (content_score + pos_score) / self.sqrt_dim
65
+
66
+ if mask is not None:
67
+ mask = mask.unsqueeze(1)
68
+ score.masked_fill_(mask, -1e9)
69
+
70
+ attn = F.softmax(score, -1)
71
+ attn = self.dropout(attn)
72
+
73
+ context = torch.matmul(attn, value).transpose(1, 2)
74
+ context = context.contiguous().view(batch_size, -1, self.d_model)
75
+
76
+ return self.out_proj(context)
77
+
78
+ def _compute_relative_positional_encoding(self, pos_score):
79
+ batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
80
+ zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
81
+ padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
82
+
83
+ padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
84
+ pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
85
+
86
+ return pos_score
87
+
88
+ # Import utility functions
89
+ try:
90
+ from utils.vocab import policy_index
91
+ from utils.fen_encoder import fen_to_tensor
92
+ except:
93
+ # Fallback implementations
94
+ def fen_to_tensor(fen: str):
95
+ """Convert FEN string to tensor representation for the model."""
96
+ board = chess.Board(fen)
97
+ tensor = np.zeros((8, 8, 19), dtype=np.float32)
98
+
99
+ # Piece mapping
100
+ piece_map = {
101
+ 'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5, # White pieces
102
+ 'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11 # Black pieces
103
+ }
104
+
105
+ # Fill piece positions
106
+ for square in chess.SQUARES:
107
+ piece = board.piece_at(square)
108
+ if piece:
109
+ row = 7 - (square // 8) # Flip vertically for proper orientation
110
+ col = square % 8
111
+ tensor[row, col, piece_map[piece.symbol()]] = 1.0
112
+
113
+ # Add metadata channels
114
+ # Channel 12: White to move
115
+ if board.turn == chess.WHITE:
116
+ tensor[:, :, 12] = 1.0
117
+
118
+ # Channel 13: Black to move
119
+ if board.turn == chess.BLACK:
120
+ tensor[:, :, 13] = 1.0
121
+
122
+ # Castling rights
123
+ if board.has_kingside_castling_rights(chess.WHITE):
124
+ tensor[:, :, 14] = 1.0
125
+ if board.has_queenside_castling_rights(chess.WHITE):
126
+ tensor[:, :, 15] = 1.0
127
+ if board.has_kingside_castling_rights(chess.BLACK):
128
+ tensor[:, :, 16] = 1.0
129
+ if board.has_queenside_castling_rights(chess.BLACK):
130
+ tensor[:, :, 17] = 1.0
131
+
132
+ # En passant
133
+ if board.ep_square is not None:
134
+ ep_row = 7 - (board.ep_square // 8)
135
+ ep_col = board.ep_square % 8
136
+ tensor[ep_row, ep_col, 18] = 1.0
137
+
138
+ return tensor
139
+
140
+ # Complete policy index with all 1929 moves
141
+ policy_index = [
142
+ "a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2", "a1b2",
143
+ "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5", "a1e5", "a1a6",
144
+ "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1", "b1c1", "b1d1", "b1e1",
145
+ "b1f1", "b1g1", "b1h1", "b1a2", "b1b2", "b1c2", "b1d2", "b1a3", "b1b3",
146
+ "b1c3", "b1d3", "b1b4", "b1e4", "b1b5", "b1f5", "b1b6", "b1g6", "b1b7",
147
+ "b1h7", "b1b8", "c1a1", "c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1",
148
+ "c1a2", "c1b2", "c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3",
149
+ "c1e3", "c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
150
+ "d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2", "d1c2",
151
+ "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3", "d1f3", "d1a4",
152
+ "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7", "d1d8", "e1a1", "e1b1",
153
+ "e1c1", "e1d1", "e1f1", "e1g1", "e1h1", "e1c2", "e1d2", "e1e2", "e1f2",
154
+ "e1g2", "e1c3", "e1d3", "e1e3", "e1f3", "e1g3", "e1b4", "e1e4", "e1h4",
155
+ "e1a5", "e1e5", "e1e6", "e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1",
156
+ "f1e1", "f1g1", "f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3",
157
+ "f1e3", "f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
158
+ "f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1", "g1f1",
159
+ "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3", "g1g3", "g1h3",
160
+ "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6", "g1a7", "g1g7", "g1g8",
161
+ "h1a1", "h1b1", "h1c1", "h1d1", "h1e1", "h1f1", "h1g1", "h1f2", "h1g2",
162
+ "h1h2", "h1f3", "h1g3", "h1h3", "h1e4", "h1h4", "h1d5", "h1h5", "h1c6",
163
+ "h1h6", "h1b7", "h1h7", "h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2",
164
+ "a2c2", "a2d2", "a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3",
165
+ "a2a4", "a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
166
+ "a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2", "b2d2",
167
+ "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3", "b2d3", "b2a4",
168
+ "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6", "b2f6", "b2b7", "b2g7",
169
+ "b2b8", "b2h8", "c2a1", "c2b1", "c2c1", "c2d1", "c2e1", "c2a2", "c2b2",
170
+ "c2d2", "c2e2", "c2f2", "c2g2", "c2h2", "c2a3", "c2b3", "c2c3", "c2d3",
171
+ "c2e3", "c2a4", "c2b4", "c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6",
172
+ "c2g6", "c2c7", "c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1",
173
+ "d2a2", "d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
174
+ "d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4", "d2a5",
175
+ "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1", "e2d1", "e2e1",
176
+ "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2", "e2f2", "e2g2", "e2h2",
177
+ "e2c3", "e2d3", "e2e3", "e2f3", "e2g3", "e2c4", "e2d4", "e2e4", "e2f4",
178
+ "e2g4", "e2b5", "e2e5", "e2h5", "e2a6", "e2e6", "e2e7", "e2e8", "f2d1",
179
+ "f2e1", "f2f1", "f2g1", "f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2",
180
+ "f2g2", "f2h2", "f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4",
181
+ "f2f4", "f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
182
+ "f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2", "g2d2",
183
+ "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3", "g2e4", "g2f4",
184
+ "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6", "g2b7", "g2g7", "g2a8",
185
+ "g2g8", "h2f1", "h2g1", "h2h1", "h2a2", "h2b2", "h2c2", "h2d2", "h2e2",
186
+ "h2f2", "h2g2", "h2f3", "h2g3", "h2h3", "h2f4", "h2g4", "h2h4", "h2e5",
187
+ "h2h5", "h2d6", "h2h6", "h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1",
188
+ "a3c1", "a3a2", "a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3",
189
+ "a3g3", "a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
190
+ "a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1", "b3d1",
191
+ "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3", "b3e3", "b3f3",
192
+ "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4", "b3a5", "b3b5", "b3c5",
193
+ "b3d5", "b3b6", "b3e6", "b3b7", "b3f7", "b3b8", "b3g8", "c3a1", "c3b1",
194
+ "c3c1", "c3d1", "c3e1", "c3a2", "c3b2", "c3c2", "c3d2", "c3e2", "c3a3",
195
+ "c3b3", "c3d3", "c3e3", "c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4",
196
+ "c3d4", "c3e4", "c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6",
197
+ "c3c7", "c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
198
+ "d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3", "d3e3",
199
+ "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4", "d3f4", "d3b5",
200
+ "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6", "d3g6", "d3d7", "d3h7",
201
+ "d3d8", "e3c1", "e3d1", "e3e1", "e3f1", "e3g1", "e3c2", "e3d2", "e3e2",
202
+ "e3f2", "e3g2", "e3a3", "e3b3", "e3c3", "e3d3", "e3f3", "e3g3", "e3h3",
203
+ "e3c4", "e3d4", "e3e4", "e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5",
204
+ "e3g5", "e3b6", "e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1",
205
+ "f3f1", "f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
206
+ "f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4", "f3f4",
207
+ "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5", "f3c6", "f3f6",
208
+ "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1", "g3g1", "g3h1", "g3e2",
209
+ "g3f2", "g3g2", "g3h2", "g3a3", "g3b3", "g3c3", "g3d3", "g3e3", "g3f3",
210
+ "g3h3", "g3e4", "g3f4", "g3g4", "g3h4", "g3e5", "g3f5", "g3g5", "g3h5",
211
+ "g3d6", "g3g6", "g3c7", "g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1",
212
+ "h3f2", "h3g2", "h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3",
213
+ "h3g3", "h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
214
+ "h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2", "a4c2",
215
+ "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4", "a4f4", "a4g4",
216
+ "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6", "a4c6", "a4a7", "a4d7",
217
+ "a4a8", "a4e8", "b4b1", "b4e1", "b4a2", "b4b2", "b4c2", "b4d2", "b4a3",
218
+ "b4b3", "b4c3", "b4d3", "b4a4", "b4c4", "b4d4", "b4e4", "b4f4", "b4g4",
219
+ "b4h4", "b4a5", "b4b5", "b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6",
220
+ "b4b7", "b4e7", "b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2",
221
+ "c4d2", "c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
222
+ "c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5", "c4d5",
223
+ "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7", "c4f7", "c4c8",
224
+ "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2", "d4d2", "d4e2", "d4f2",
225
+ "d4b3", "d4c3", "d4d3", "d4e3", "d4f3", "d4a4", "d4b4", "d4c4", "d4e4",
226
+ "d4f4", "d4g4", "d4h4", "d4b5", "d4c5", "d4d5", "d4e5", "d4f5", "d4b6",
227
+ "d4c6", "d4d6", "d4e6", "d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8",
228
+ "e4b1", "e4e1", "e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3",
229
+ "e4d3", "e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
230
+ "e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6", "e4d6",
231
+ "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8", "e4e8", "f4c1",
232
+ "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2", "f4d3", "f4e3", "f4f3",
233
+ "f4g3", "f4h3", "f4a4", "f4b4", "f4c4", "f4d4", "f4e4", "f4g4", "f4h4",
234
+ "f4d5", "f4e5", "f4f5", "f4g5", "f4h5", "f4d6", "f4e6", "f4f6", "f4g6",
235
+ "f4h6", "f4c7", "f4f7", "f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2",
236
+ "g4g2", "g4h2", "g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4",
237
+ "g4d4", "g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
238
+ "g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1", "h4h1",
239
+ "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4", "h4b4", "h4c4",
240
+ "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5", "h4h5", "h4f6", "h4g6",
241
+ "h4h6", "h4e7", "h4h7", "h4d8", "h4h8", "a5a1", "a5e1", "a5a2", "a5d2",
242
+ "a5a3", "a5b3", "a5c3", "a5a4", "a5b4", "a5c4", "a5b5", "a5c5", "a5d5",
243
+ "a5e5", "a5f5", "a5g5", "a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7",
244
+ "a5c7", "a5a8", "a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3",
245
+ "b5c3", "b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
246
+ "b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6", "b5a7",
247
+ "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1", "c5c2", "c5f2",
248
+ "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4", "c5b4", "c5c4", "c5d4",
249
+ "c5e4", "c5a5", "c5b5", "c5d5", "c5e5", "c5f5", "c5g5", "c5h5", "c5a6",
250
+ "c5b6", "c5c6", "c5d6", "c5e6", "c5a7", "c5b7", "c5c7", "c5d7", "c5e7",
251
+ "c5c8", "c5f8", "d5d1", "d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3",
252
+ "d5d3", "d5e3", "d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5",
253
+ "d5b5", "d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
254
+ "d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8", "d5d8",
255
+ "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3", "e5d3", "e5e3",
256
+ "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4", "e5g4", "e5a5", "e5b5",
257
+ "e5c5", "e5d5", "e5f5", "e5g5", "e5h5", "e5c6", "e5d6", "e5e6", "e5f6",
258
+ "e5g6", "e5c7", "e5d7", "e5e7", "e5f7", "e5g7", "e5b8", "e5e8", "e5h8",
259
+ "f5b1", "f5f1", "f5c2", "f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3",
260
+ "f5d4", "f5e4", "f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5",
261
+ "f5e5", "f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
262
+ "f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1", "g5d2",
263
+ "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4", "g5g4", "g5h4",
264
+ "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5", "g5h5", "g5e6", "g5f6",
265
+ "g5g6", "g5h6", "g5e7", "g5f7", "g5g7", "g5h7", "g5d8", "g5g8", "h5d1",
266
+ "h5h1", "h5e2", "h5h2", "h5f3", "h5g3", "h5h3", "h5f4", "h5g4", "h5h4",
267
+ "h5a5", "h5b5", "h5c5", "h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6",
268
+ "h5h6", "h5f7", "h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2",
269
+ "a6e2", "a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
270
+ "a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7", "a6b7",
271
+ "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2", "b6f2", "b6b3",
272
+ "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5", "b6b5", "b6c5", "b6d5",
273
+ "b6a6", "b6c6", "b6d6", "b6e6", "b6f6", "b6g6", "b6h6", "b6a7", "b6b7",
274
+ "b6c7", "b6d7", "b6a8", "b6b8", "b6c8", "b6d8", "c6c1", "c6h1", "c6c2",
275
+ "c6g2", "c6c3", "c6f3", "c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5",
276
+ "c6b5", "c6c5", "c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6",
277
+ "c6g6", "c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
278
+ "c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3", "d6g3",
279
+ "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5", "d6d5", "d6e5",
280
+ "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6", "d6g6", "d6h6", "d6b7",
281
+ "d6c7", "d6d7", "d6e7", "d6f7", "d6b8", "d6c8", "d6d8", "d6e8", "d6f8",
282
+ "e6e1", "e6a2", "e6e2", "e6b3", "e6e3", "e6h3", "e6c4", "e6d4", "e6e4",
283
+ "e6f4", "e6g4", "e6c5", "e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6",
284
+ "e6c6", "e6d6", "e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7",
285
+ "e6g7", "e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
286
+ "f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4", "f6d5",
287
+ "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6", "f6d6", "f6e6",
288
+ "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7", "f6h7", "f6d8", "f6e8",
289
+ "f6f8", "f6g8", "f6h8", "g6b1", "g6g1", "g6c2", "g6g2", "g6d3", "g6g3",
290
+ "g6e4", "g6f4", "g6g4", "g6h4", "g6e5", "g6f5", "g6g5", "g6h5", "g6a6",
291
+ "g6b6", "g6c6", "g6d6", "g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7",
292
+ "g6h7", "g6e8", "g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2",
293
+ "h6e3", "h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
294
+ "h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7", "h6h7",
295
+ "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2", "a7a3", "a7e3",
296
+ "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6", "a7b6", "a7c6", "a7b7",
297
+ "a7c7", "a7d7", "a7e7", "a7f7", "a7g7", "a7h7", "a7a8", "a7b8", "a7c8",
298
+ "b7b1", "b7h1", "b7b2", "b7g2", "b7b3", "b7f3", "b7b4", "b7e4", "b7a5",
299
+ "b7b5", "b7c5", "b7d5", "b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7",
300
+ "b7d7", "b7e7", "b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8",
301
+ "c7c1", "c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
302
+ "c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6", "c7a7",
303
+ "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8", "c7b8", "c7c8",
304
+ "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3", "d7a4", "d7d4", "d7g4",
305
+ "d7b5", "d7c5", "d7d5", "d7e5", "d7f5", "d7b6", "d7c6", "d7d6", "d7e6",
306
+ "d7f6", "d7a7", "d7b7", "d7c7", "d7e7", "d7f7", "d7g7", "d7h7", "d7b8",
307
+ "d7c8", "d7d8", "d7e8", "d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4",
308
+ "e7e4", "e7h4", "e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6",
309
+ "e7e6", "e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
310
+ "e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2", "f7f2",
311
+ "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5", "f7g5", "f7h5",
312
+ "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7", "f7b7", "f7c7", "f7d7",
313
+ "f7e7", "f7g7", "f7h7", "f7d8", "f7e8", "f7f8", "f7g8", "f7h8", "g7a1",
314
+ "g7g1", "g7b2", "g7g2", "g7c3", "g7g3", "g7d4", "g7g4", "g7e5", "g7f5",
315
+ "g7g5", "g7h5", "g7e6", "g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7",
316
+ "g7d7", "g7e7", "g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1",
317
+ "h7h1", "h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
318
+ "h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7", "h7e7",
319
+ "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1", "a8a2", "a8g2",
320
+ "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5", "a8a6", "a8b6", "a8c6",
321
+ "a8a7", "a8b7", "a8c7", "a8b8", "a8c8", "a8d8", "a8e8", "a8f8", "a8g8",
322
+ "a8h8", "b8b1", "b8b2", "b8h2", "b8b3", "b8g3", "b8b4", "b8f4", "b8b5",
323
+ "b8e5", "b8a6", "b8b6", "b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7",
324
+ "b8a8", "b8c8", "b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2",
325
+ "c8c3", "c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
326
+ "c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8", "c8b8",
327
+ "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2", "d8d3", "d8d4",
328
+ "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6", "d8d6", "d8e6", "d8f6",
329
+ "d8b7", "d8c7", "d8d7", "d8e7", "d8f7", "d8a8", "d8b8", "d8c8", "d8e8",
330
+ "d8f8", "d8g8", "d8h8", "e8e1", "e8e2", "e8e3", "e8a4", "e8e4", "e8b5",
331
+ "e8e5", "e8h5", "e8c6", "e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7",
332
+ "e8e7", "e8f7", "e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8",
333
+ "e8h8", "f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
334
+ "f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7", "f8g7",
335
+ "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8", "f8h8", "g8g1",
336
+ "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4", "g8d5", "g8g5", "g8e6",
337
+ "g8f6", "g8g6", "g8h6", "g8e7", "g8f7", "g8g7", "g8h7", "g8a8", "g8b8",
338
+ "g8c8", "g8d8", "g8e8", "g8f8", "g8h8", "h8a1", "h8h1", "h8b2", "h8h2",
339
+ "h8c3", "h8h3", "h8d4", "h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6",
340
+ "h8f7", "h8g7", "h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8",
341
+ "h8g8", "a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q",
342
+ "b7a8r", "b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b",
343
+ "c7b8q", "c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r",
344
+ "c7d8b", "d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q",
345
+ "d7e8r", "d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b",
346
+ "e7f8q", "e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r",
347
+ "f7f8b", "f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q",
348
+ "g7g8r", "g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b",
349
+ "h7h8q", "h7h8r", "h7h8b", #add the promotions for black
350
+ "a2a1q","a2a1r","a2a1b","a2b1q","a2b1r","a2b1b",
351
+ "b2a1q","b2a1r","b2a1b","b2b1q","b2b1r","b2b1b","b2c1q","b2c1r","b2c1b",
352
+ "c2b1q","c2b1r","c2b1b","c2c1q","c2c1r","c2c1b","c2d1q","c2d1r","c2d1b",
353
+ "d2c1q","d2c1r","d2c1b","d2d1q","d2d1r","d2d1b","d2e1q","d2e1r","d2e1b",
354
+ "e2d1q","e2d1r","e2d1b","e2e1q","e2e1r","e2e1b","e2f1q","e2f1r","e2f1b",
355
+ "f2e1q","f2e1r","f2e1b","f2f1q","f2f1r","f2f1b","f2g1q","f2g1r","f2g1b",
356
+ "g2f1q","g2f1r","g2f1b","g2g1q","g2g1r","g2g1b","g2h1q","g2h1r","g2h1b",
357
+ "h2g1q","h2g1r","h2g1b","h2h1q","h2h1r","h2h1b",#add special tokens
358
+ "<thinking>","</thinking>","end_variation","end","padding_token"
359
+ ]
360
 
361
 
362
  # Configuration class
 
371
  self,
372
  num_layers: int = 10,
373
  d_model: int = 512,
374
+ d_ff: int = 1024, # Updated to match new architecture
375
  num_heads: int = 8,
376
  vocab_size: int = 1929,
377
  max_position_embeddings: int = 64,
 
386
  self.max_position_embeddings = max_position_embeddings
387
 
388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  # Model components
390
  class MaGating(nn.Module):
391
  def __init__(self, d_model):
 
413
  x = self.norm1(x)
414
 
415
  y = self.ff1(x)
 
416
  y = self.gelu(y)
417
+ y = self.ff2(y)
418
  y = y + x
419
  y = self.norm2(y)
420
 
 
499
 
500
  class ChessBotModel(ChessBotPreTrainedModel):
501
  """
502
+ Updated HuggingFace compatible ChessBot Chess model with d_ff=1024
503
  """
504
 
505
  def __init__(self, config):
506
  super().__init__(config)
507
  self.config = config
508
 
509
+ # Initialize exactly like the updated BT4 model
510
  self.is_thinking_model = False
511
  self.d_model = config.d_model
512
  self.num_layers = config.num_layers
513
 
514
+ # Model layers - same as updated model
515
  self.layers = nn.ModuleList([
516
  EncoderLayer(config.d_model, config.d_ff, config.num_heads)
517
  for _ in range(config.num_layers)
 
580
  targets = inp[1]
581
  true_values = inp[3]
582
  q_values = inp[4]
583
+ true_values = q_values
584
+
585
  z = torch.argmax(true_values, dim=-1)
586
+ q = torch.argmax(q_values, dim=-1)
587
+ value_h_q_softmax = torch.softmax(value_h_q, dim=-1)
588
+
589
+ # Always compute policy loss
590
+ loss_policy = F.cross_entropy(policy.view(-1, policy.size(-1)), targets.view(-1), ignore_index=1928)
591
+
592
+ # Create mask for samples where true_values/q_values is not [0,0,0]
593
+ valid_mask = (true_values.sum(dim=-1) != 0) & (q_values.sum(dim=-1) != 0)
594
+
595
+ # Only compute value losses if we have valid samples
596
+ if valid_mask.any():
597
+ # Filter to only valid samples
598
+ valid_value_h = value_h[valid_mask]
599
+ valid_value_h_q = value_h_q_softmax[valid_mask]
600
+ valid_z = z[valid_mask]
601
+ valid_q_values = q_values[valid_mask]
602
+
603
+ loss_value = F.cross_entropy(valid_value_h.view(-1, valid_value_h.size(-1)), valid_z.view(-1))
604
+ loss_q = F.mse_loss(valid_value_h_q.view(-1, valid_value_h_q.size(-1)), valid_q_values.view(-1, 3))
605
+ else:
606
+ # No valid samples, set losses to zero
607
+ loss_value = torch.tensor(0.0, device=value_h.device, requires_grad=True)
608
+ loss_q = torch.tensor(0.0, device=value_h_q.device, requires_grad=True)
609
+
610
+ return policy, value_h, value_h_q, loss_policy, loss_value, loss_q, targets, z, q
611
 
612
  return BaseModelOutput(
613
  last_hidden_state=x,
 
616
  ), policy, value_h, value_h_q
617
 
618
  def get_move_from_fen_no_thinking(self, fen, T=1, device="cuda", force_legal=True, return_probs=False):
619
+ """Get a move from FEN string without thinking"""
 
 
620
  board = chess.Board(fen)
621
+ x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
622
+ x = x.view(1, 1, 8, 8, 19)
 
 
 
 
 
 
 
 
 
 
 
623
 
624
+ _, logits, _, _ = self.forward(x)
625
+ logits = logits.view(-1, 1929)
626
+ legal_move_mask = torch.zeros((1, 1929), device=device)
627
+ for legal_move in board.legal_moves:
628
+ if legal_move.uci()[-1] == 'n':
629
+ legal_move_uci = legal_move.uci()[:-1]
 
 
630
  else:
631
+ legal_move_uci = legal_move.uci()
632
+ if legal_move_uci in policy_index:
633
+ legal_move_mask[0][policy_index.index(legal_move_uci)] = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
 
635
+ # Set all illegal moves to -inf
636
  if force_legal:
637
+ logits = logits + (1-legal_move_mask) * -999
 
 
638
 
639
+ if T == 0:
640
+ sampled = torch.argmax(logits, dim=-1, keepdim=True)
 
 
641
  else:
642
+ probs = F.softmax(logits/T, dim=-1)
643
+ sampled = torch.multinomial(probs, num_samples=1)
644
+ if return_probs:
645
+ return sampled, probs
646
 
647
+ move = policy_index[sampled.item()]
648
+ return move
649
+
650
  def get_position_value(self, fen, device="cuda"):
651
+ """Get the value evaluation for a given FEN position."""
 
 
 
652
  x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
653
  x = x.view(1, 1, 8, 8, 19)
654
 
 
655
  with torch.no_grad():
 
656
  b, seq_len, _, _, emb = x.size()
657
  x_processed = x.view(b * seq_len, 64, emb)
658
  x_processed = self.linear1(x_processed)
 
666
 
667
  value_logits = self.value_head_q(x_processed)
668
  value_logits = value_logits.view(b, seq_len, 3)
669
+ value = torch.softmax(value_logits, dim=-1)
670
+ return value.squeeze()
 
671
 
672
  def get_batch_position_values(self, fens, device="cuda"):
673
+ """Get the value evaluation for a batch of FEN positions efficiently."""
 
 
 
 
 
 
 
674
  if len(fens) == 0:
675
  return torch.empty(0, 3, device=device)
676
 
 
677
  position_tensors = []
678
  for fen in fens:
679
  x = torch.from_numpy(fen_to_tensor(fen)).to(device).to(torch.float32)
680
  position_tensors.append(x)
681
 
 
682
  batch_x = torch.stack(position_tensors, dim=0)
 
683
  batch_x = batch_x.unsqueeze(1)
684
 
 
685
  with torch.no_grad():
686
  b, seq_len, _, _, emb = batch_x.size()
687
  x_processed = batch_x.view(b * seq_len, 64, emb)
 
697
  value_logits = self.value_head_q(x_processed)
698
  value_logits = value_logits.view(b, seq_len, 3)
699
  value_logits = torch.softmax(value_logits, dim=-1)
700
+ return value_logits.squeeze(1)
701
 
702
  def calculate_move_values(self, fen, device="cuda"):
703
+ """Calculate the value for each legal move from the given position efficiently using batching."""
 
 
 
 
704
  board = chess.Board()
705
  board.set_fen(fen)
706
 
 
707
  is_white_turn = board.turn == chess.WHITE
708
 
709
  legal_moves = list(board.legal_moves)
710
  if len(legal_moves) == 0:
711
  return [], torch.empty(0, device=device)
712
 
 
713
  resulting_fens = []
714
  for move in legal_moves:
715
  board.push(move)
716
  resulting_fens.append(board.fen())
717
  board.pop()
718
 
 
719
  batch_value_q = self.get_batch_position_values(resulting_fens, device)
720
 
721
  # Calculate values from the current player's perspective
 
722
  batch_value_q = batch_value_q[:,2]-batch_value_q[:,0]
723
  if is_white_turn:
 
724
  player_values = batch_value_q
725
  else:
 
726
  player_values = -batch_value_q
727
 
728
  return legal_moves, player_values
729
 
730
+ def get_best_move_value(self, fen, T=1, device="cuda", return_probs=False, to_fall_back_to_policy=False):
731
+ """Determine the best move based on the value of resulting positions using efficient batching."""
732
+ # Check if we should fall back to policy
733
+ if to_fall_back_to_policy:
734
+ value = self.get_position_value(fen, device)
735
+ board = chess.Board()
736
+ board.set_fen(fen)
 
 
737
 
738
+ is_white_turn = board.turn == chess.WHITE
739
+ if is_white_turn:
740
+ value = value[2]-value[0]
741
+ else:
742
+ value = value[0]-value[2]
743
+
744
+ if value > 0.9:
745
+ return self.get_move_from_fen_no_thinking(fen, T, device, force_legal=True, return_probs=return_probs)
746
+
747
  legal_moves, move_values = self.calculate_move_values(fen, device)
748
 
749
  if len(legal_moves) == 0:
750
  raise ValueError("No legal moves available")
751
 
752
  if T == 0:
 
753
  best_idx = torch.argmax(move_values)
754
  selected_move = legal_moves[best_idx]
755
  else:
 
 
756
  probs = F.softmax(move_values / T, dim=0)
 
 
757
  sampled_idx = torch.multinomial(probs, num_samples=1)
758
  selected_move = legal_moves[sampled_idx.item()]
759
 
 
760
  move_uci = selected_move.uci()
761
 
762
  if return_probs:
763
  if T == 0:
 
764
  probs = torch.zeros_like(move_values)
765
  probs[best_idx] = 1.0
766
  else:
767
  probs = F.softmax(move_values / T, dim=0)
768
 
 
769
  move_dict = {}
770
  for i, move in enumerate(legal_moves):
771
  move_dict[move.uci()] = probs[i].item()