Spaces:
Running
Running
import chess | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class O2Net(nn.Module):
    """Fully connected residual network with a policy head and a value head.

    The trunk is an input projection followed by a stack of residual blocks
    (Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d, plus a skip
    connection and a ReLU). Two separate MLP heads read the trunk output.

    The default sizes reproduce the original fixed architecture: a 1152-wide
    input (18 planes of 8x8, flattened), a 1024-wide trunk with 10 residual
    blocks and a 4672-wide policy output.
    NOTE(review): 4672 presumably matches an 8x8x73 move encoding; confirm
    against whatever maps policy indices to moves.

    The residual blocks use BatchNorm1d, which cannot compute batch statistics
    from a single sample in training mode. Call ``eval()`` before running
    inference on one position.

    Args:
        input_dim: Length of the flattened input feature vector.
        hidden_dim: Width of the trunk and of every residual block.
        num_res_blocks: Number of residual blocks in the trunk.
        policy_dim: Number of policy logits produced by the policy head.
    """

    def __init__(self, input_dim=1152, hidden_dim=1024, num_res_blocks=10,
                 policy_dim=4672):
        super().__init__()
        # Attribute names and construction order are kept stable so that
        # existing state_dict checkpoints and seeded initialisation still match.

        # Input projection.
        self.input_fc = nn.Linear(input_dim, hidden_dim)

        # Residual trunk.
        self.res_blocks = nn.ModuleList(
            [self._make_res_block(hidden_dim) for _ in range(num_res_blocks)]
        )
        self.res_relu = nn.ReLU()

        # Policy head.
        self.policy_fc1 = nn.Linear(hidden_dim, 512)
        self.policy_fc2 = nn.Linear(512, 256)
        self.policy_fc3 = nn.Linear(256, policy_dim)

        # Value head.
        self.value_fc1 = nn.Linear(hidden_dim, 512)
        self.value_fc2 = nn.Linear(512, 128)
        self.value_fc3 = nn.Linear(128, 1)

    @staticmethod
    def _make_res_block(width):
        """Build the transform branch of one residual block (skip added in forward)."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
        )

    def forward(self, x):
        """Run the network on a batch of encoded positions.

        Args:
            x: Float tensor whose last dimension is ``input_dim``; the script
                in this file passes shape ``(1, 1152)``.

        Returns:
            A ``(policy, value)`` tuple. ``policy`` holds the raw output of the
            last policy layer (no softmax is applied here). ``value`` is the
            value-head output squashed by ``tanh``, so it lies in [-1, 1].
        """
        x = F.relu(self.input_fc(x))

        for block in self.res_blocks:
            # Skip connection: add the block input back before the activation.
            x = self.res_relu(block(x) + x)

        # Policy head.
        p = F.relu(self.policy_fc1(x))
        p = F.relu(self.policy_fc2(p))
        policy = self.policy_fc3(p)

        # Value head.
        v = F.relu(self.value_fc1(x))
        v = F.relu(self.value_fc2(v))
        value = torch.tanh(self.value_fc3(v))

        return policy, value
def board_to_tensor(board):
    """Encode a board position as a flat float32 feature vector.

    The encoding is built as an array of shape (18, 8, 8) and then flattened
    to length 1152:

    * planes 0-11: one plane per piece type and colour. The plane index is
      ``piece_type - 1``, with 6 added for pieces that are not white. A piece
      on ``square`` sets cell ``[square // 8, square % 8]`` to 1.
    * plane 12: filled with ``int(board.turn)``.
    * planes 13-16: filled with 0/1 for white kingside, white queenside,
      black kingside and black queenside castling rights, in that order.
    * plane 17: a single 1 at the en passant square, if there is one.

    Args:
        board: Position to encode. It must provide ``piece_map()``, ``turn``,
            the castling-rights queries and ``ep_square``; presumably a
            ``chess.Board``.

    Returns:
        A 1-D ``numpy.ndarray`` of dtype float32 with 1152 entries.
    """
    encoded = np.zeros((18, 8, 8), dtype=np.float32)

    # Piece planes.
    for sq, pc in board.piece_map().items():
        colour_offset = 0 if pc.color == chess.WHITE else 6
        encoded[(pc.piece_type - 1) + colour_offset, sq // 8, sq % 8] = 1

    # Whole-plane state flags, written to consecutive planes from index 12.
    state_flags = (
        board.turn,
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for plane_idx, flag in enumerate(state_flags, start=12):
        encoded[plane_idx, :, :] = int(flag)

    # En passant target square, if any.
    ep = board.ep_square
    if ep is not None:
        encoded[17, ep // 8, ep % 8] = 1

    return encoded.flatten()
if __name__ == "__main__":
    # Smoke test: push the starting position through an untrained network.
    board = chess.Board()
    net = O2Net()
    # The residual blocks contain BatchNorm1d, which raises ValueError in
    # training mode when given a batch of one sample. Switch to eval mode so
    # the running statistics are used instead.
    net.eval()
    x = torch.tensor(board_to_tensor(board)).unsqueeze(0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        policy, value = net(x)
    print("Policy shape:", policy.shape)
    print("Value:", value.item())