import chess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class O2Net(nn.Module):
    def __init__(self):
        super(O2Net, self).__init__()
        # Input layer: 1152 = 18 planes x 8 x 8 (see board_to_tensor)
        self.input_fc = nn.Linear(1152, 1024)
        # 10 residual blocks, each with two fully connected layers
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1024, 1024),
                nn.BatchNorm1d(1024),
                nn.ReLU(),
                nn.Linear(1024, 1024),
                nn.BatchNorm1d(1024)
            ) for _ in range(10)
        ])
        self.res_relu = nn.ReLU()
        # Policy head
        self.policy_fc1 = nn.Linear(1024, 512)
        self.policy_fc2 = nn.Linear(512, 256)
        self.policy_fc3 = nn.Linear(256, 4672)  # 4672 = 8 x 8 x 73, the AlphaZero-style move encoding
        # Value head
        self.value_fc1 = nn.Linear(1024, 512)
        self.value_fc2 = nn.Linear(512, 128)
        self.value_fc3 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.input_fc(x))
        for block in self.res_blocks:
            residual = x
            out = block(x)
            x = self.res_relu(out + residual)
        # Policy head
        p = F.relu(self.policy_fc1(x))
        p = F.relu(self.policy_fc2(p))
        policy = self.policy_fc3(p)
        # Value head
        v = F.relu(self.value_fc1(x))
        v = F.relu(self.value_fc2(v))
        value = torch.tanh(self.value_fc3(v))
        return policy, value

def board_to_tensor(board):
    # Encoding: 18 planes of 8x8, flattened to a length-1152 vector
    # 12 planes: one per piece type and color
    # 6 planes: side to move, castling rights (4), en passant square
    planes = np.zeros((18, 8, 8), dtype=np.float32)
    piece_map = board.piece_map()
    for square, piece in piece_map.items():
        plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
        row, col = divmod(square, 8)
        planes[plane, row, col] = 1
    # Turn plane
    planes[12, :, :] = int(board.turn)
    # Castling rights
    planes[13, :, :] = int(board.has_kingside_castling_rights(chess.WHITE))
    planes[14, :, :] = int(board.has_queenside_castling_rights(chess.WHITE))
    planes[15, :, :] = int(board.has_kingside_castling_rights(chess.BLACK))
    planes[16, :, :] = int(board.has_queenside_castling_rights(chess.BLACK))
    # En passant
    if board.ep_square is not None:
        row, col = divmod(board.ep_square, 8)
        planes[17, row, col] = 1
    return planes.flatten()
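
# Illustrative sketch: one possible training step for O2Net, assuming supervised
# targets are available. The Adam optimizer, the equal weighting of the two losses,
# and the random dummy targets are placeholder assumptions for demonstration only.
def example_training_step():
    net = O2Net()
    net.train()  # BatchNorm1d requires a batch size > 1 in training mode
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    # Dummy batch: 8 copies of the starting position with random targets.
    board = chess.Board()
    x = torch.tensor(np.stack([board_to_tensor(board)] * 8))
    policy_target = torch.randint(0, 4672, (8,))   # placeholder move indices
    value_target = torch.rand(8, 1) * 2 - 1        # placeholder game outcomes in [-1, 1]
    policy_logits, value = net(x)
    loss = F.cross_entropy(policy_logits, policy_target) + F.mse_loss(value, value_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()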

if __name__ == "__main__":
    # Quick smoke test on the starting position.
    board = chess.Board()
    net = O2Net()
    net.eval()  # eval mode: BatchNorm1d rejects a batch of size 1 during training
    x = torch.tensor(board_to_tensor(board)).unsqueeze(0)
    with torch.no_grad():
        policy, value = net(x)
    print("Policy shape:", policy.shape)
    print("Value:", value.item())