# Play-with-o2 / src/o2_model.py
import chess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class O2Net(nn.Module):
    """Fully connected policy/value network over a flattened 8x8x18 board encoding."""

    def __init__(self):
        super().__init__()
        # Input layer: 1152 = 8x8x18 board encoding, flattened (see board_to_tensor)
        self.input_fc = nn.Linear(1152, 1024)
        # 10 residual blocks: Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm
self.res_blocks = nn.ModuleList([
nn.Sequential(
nn.Linear(1024, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, 1024),
nn.BatchNorm1d(1024)
) for _ in range(10)
])
self.res_relu = nn.ReLU()
        # Policy head: 4672 logits, matching the 64x73 (8x8x73) AlphaZero-style
        # move encoding; the move-to-index mapping is not defined in this file.
        self.policy_fc1 = nn.Linear(1024, 512)
        self.policy_fc2 = nn.Linear(512, 256)
        self.policy_fc3 = nn.Linear(256, 4672)
# Value head
self.value_fc1 = nn.Linear(1024, 512)
self.value_fc2 = nn.Linear(512, 128)
self.value_fc3 = nn.Linear(128, 1)
    def forward(self, x):
        # Returns raw policy logits (no softmax applied) and a tanh value in [-1, 1].
x = F.relu(self.input_fc(x))
for block in self.res_blocks:
residual = x
out = block(x)
x = self.res_relu(out + residual)
# Policy head
p = F.relu(self.policy_fc1(x))
p = F.relu(self.policy_fc2(p))
policy = self.policy_fc3(p)
# Value head
v = F.relu(self.value_fc1(x))
v = F.relu(self.value_fc2(v))
value = torch.tanh(self.value_fc3(v))
return policy, value
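

# --- Illustrative sketch (not part of the original upload) -------------------
# Shows one way the 4672-dim policy output could be restricted to a position's
# legal moves. The mapping from chess.Move to a policy index is NOT defined in
# this module, so it is taken here as a caller-supplied callable `move_to_index`
# (hypothetical); everything else uses only objects imported or defined above.
def masked_policy(policy_logits, board, move_to_index):
    """Return {move: probability} over board.legal_moves.

    policy_logits: 1-D tensor of length 4672 (one O2Net policy output).
    move_to_index: assumed callable mapping a chess.Move to an int in [0, 4672).
    """
    legal = list(board.legal_moves)
    if not legal:
        return {}
    logits = torch.stack([policy_logits[move_to_index(m)] for m in legal])
    probs = F.softmax(logits, dim=0)
    return {move: prob.item() for move, prob in zip(legal, probs)}
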
def board_to_tensor(board):
    # Encoding: 8x8x18 planes, flattened to a length-1152 vector.
    #   12 planes: one per piece type and color
    #    6 planes: side to move, castling rights (4), en passant square
    planes = np.zeros((18, 8, 8), dtype=np.float32)
piece_map = board.piece_map()
    for square, piece in piece_map.items():
        # Planes 0-5: white pawn..king; planes 6-11: black pawn..king.
        plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
        # python-chess squares run 0 (a1) .. 63 (h8), so divmod gives (rank, file).
        row, col = divmod(square, 8)
        planes[plane, row, col] = 1
# Turn plane
planes[12, :, :] = int(board.turn)
# Castling rights
planes[13, :, :] = int(board.has_kingside_castling_rights(chess.WHITE))
planes[14, :, :] = int(board.has_queenside_castling_rights(chess.WHITE))
planes[15, :, :] = int(board.has_kingside_castling_rights(chess.BLACK))
planes[16, :, :] = int(board.has_queenside_castling_rights(chess.BLACK))
# En passant
if board.ep_square is not None:
row, col = divmod(board.ep_square, 8)
planes[17, row, col] = 1
return planes.flatten()
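

# --- Illustrative sketch (not part of the original upload) -------------------
# Minimal single-position inference helper combining O2Net and board_to_tensor.
# eval() avoids the BatchNorm1d batch-size-1 error and freezes running stats;
# no_grad() skips autograd bookkeeping during inference.
def predict(net, board):
    """Return (policy_logits, value) for one position as (tensor of shape [4672], float)."""
    net.eval()
    x = torch.tensor(board_to_tensor(board)).unsqueeze(0)
    with torch.no_grad():
        policy, value = net(x)
    return policy.squeeze(0), value.item()
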
if __name__ == "__main__":
    board = chess.Board()
    net = O2Net()
    # eval() is required: BatchNorm1d rejects a batch of size 1 in training mode,
    # and inference should not update the running statistics anyway.
    net.eval()
    x = torch.tensor(board_to_tensor(board)).unsqueeze(0)
    with torch.no_grad():
        policy, value = net(x)
    print("Policy shape:", policy.shape)
    print("Value:", value.item())