# dualist / model.py
# Author: brandonlanexyz
# Initial upload of Dualist Othello AI (Iteration 652)
# Commit: cf2aacd (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    """Two-convolution residual unit with an identity skip connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``. Both convolutions
    are 3x3, stride 1, padding 1, with ``channels`` in and out, so the output
    has the same shape as the input and the skip can be a plain addition.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Each conv feeds straight into BatchNorm, whose learned shift makes a
        # conv bias redundant, hence bias=False.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        # Attribute names and creation order are significant: they fix the
        # state_dict keys and the order parameters are randomly initialised.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Apply the block; the result has the same shape as ``x``."""
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        # The final ReLU is applied after the skip connection is added.
        return F.relu(branch + x)
class OthelloNet(nn.Module):
    """Policy/value network for Othello built on a residual convolutional tower.

    Architecture: a 3x3 convolutional stem, ``num_res_blocks`` residual blocks,
    then two heads that both read the tower output:

    * policy head -> log-probabilities over ``board_size ** 2 + 1`` actions
      (one per square plus a final "pass" action);
    * value head -> one scalar per position, squashed into ``[-1, 1]`` by tanh.

    The stem and tower convolutions use stride 1 with padding that preserves
    the board's spatial size.

    Args:
        num_res_blocks: Number of residual blocks in the tower.
        num_channels: Feature channels used by the stem and every residual block.
        board_size: Side length of the square board. Defaults to 8, which
            reproduces the original fixed 8x8 layer shapes exactly.
        in_channels: Number of input feature planes. Defaults to 3 (per the
            original author: player pieces, opponent pieces, and a
            legal-moves/constant plane; the exact encoding is up to the caller).
    """

    def __init__(self, num_res_blocks=10, num_channels=256, board_size=8, in_channels=3):
        super().__init__()
        self.board_size = board_size
        num_squares = board_size * board_size
        # Policy output covers every square plus one "pass" action.
        self.action_size = num_squares + 1

        # Input stem. Convs that feed straight into BatchNorm omit their bias,
        # since BatchNorm's learned shift makes it redundant.
        self.conv_input = nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn_input = nn.BatchNorm2d(num_channels)

        # Residual tower.
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(num_channels) for _ in range(num_res_blocks)]
        )

        # Policy head: 1x1 conv down to 2 planes, then a linear layer over the
        # flattened 2 * board_size**2 features (128 for the default 8x8 board).
        self.policy_conv = nn.Conv2d(num_channels, 2, kernel_size=1, bias=False)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(2 * num_squares, self.action_size)

        # Value head: 1x1 conv down to 1 plane (board_size**2 features, 64 for
        # 8x8), a 256-unit hidden layer, then a single output.
        self.value_conv = nn.Conv2d(num_channels, 1, kernel_size=1, bias=False)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(num_squares, 256)
        self.value_fc2 = nn.Linear(256, 1)

    def forward(self, x):
        """Evaluate a batch of positions.

        Args:
            x: Tensor of shape ``(batch, in_channels, board_size, board_size)``.

        Returns:
            Tuple ``(p, v)``:

            * ``p`` -- shape ``(batch, board_size**2 + 1)``; log-probabilities
              produced by ``log_softmax`` (not raw logits). Exponentiate to get
              probabilities; pass directly to ``NLLLoss`` / ``KLDivLoss``,
              which expect log-space inputs.
            * ``v`` -- shape ``(batch, 1)``, values in ``[-1, 1]``.
        """
        # Stem + residual tower.
        x = F.relu(self.bn_input(self.conv_input(x)))
        for block in self.res_blocks:
            x = block(x)

        # Policy head. torch.flatten (unlike .view) also works when the
        # activation is non-contiguous, e.g. under channels_last memory format.
        p = F.relu(self.policy_bn(self.policy_conv(x)))
        p = torch.flatten(p, 1)
        # log_softmax is the numerically stable way to get log-probabilities.
        p = F.log_softmax(self.policy_fc(p), dim=1)

        # Value head.
        v = F.relu(self.value_bn(self.value_conv(x)))
        v = torch.flatten(v, 1)
        v = F.relu(self.value_fc1(v))
        v = torch.tanh(self.value_fc2(v))
        return p, v