"""Residual blocks and an AlphaZero-style policy/value network, in PyTorch."""

from torch import nn

class BasicBlock(nn.Module):
    """Standard two-convolution residual block (He et al., 2016).

    Assumes in_channels == channels so the identity shortcut can be added
    without a projection.
    """

    def __init__(self, in_channels, channels, bias, k=3, p=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, k, stride=1, padding=p, bias=bias)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, k, stride=1, padding=p, bias=bias)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = self.bn2(y)
        # Identity shortcut: add the input back before the final activation.
        x = x + y
        x = self.relu2(x)
        return x

class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand.

    The middle width is channels // 2; as with BasicBlock, the identity
    shortcut assumes in_channels == channels.
    """

    def __init__(self, in_channels, channels, bias):
        super().__init__()
        mid_channels = channels // 2
        # 1x1 convolution reduces the channel count for the 3x3 stage.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()
        # 1x1 convolution restores the full channel count.
        self.conv3 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
        self.bn3 = nn.BatchNorm2d(channels)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y = self.relu2(y)
        y = self.conv3(y)
        y = self.bn3(y)
        # Identity shortcut around the whole bottleneck.
        x = x + y
        x = self.relu3(x)
        return x

class Bottlenest(nn.Module):
    """Bottleneck with two nested residual sub-blocks in the reduced space.

    Structure: 1x1 reduce -> (3x3 -> 3x3, residual) x 2 -> 1x1 expand, with
    an outer identity shortcut around the whole block.
    """

    def __init__(self, in_channels, channels, bias):
        super().__init__()
        mid_channels = channels // 2
        # Entry: 1x1 reduction.
        self.conv0 = nn.Conv2d(in_channels, mid_channels, 1, 1, bias=bias)
        self.bn0 = nn.BatchNorm2d(mid_channels)
        # First inner residual pair.
        self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.relu2 = nn.ReLU()
        # Second inner residual pair.
        self.conv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(mid_channels, mid_channels, 3, 1, padding=1, bias=bias)
        self.bn4 = nn.BatchNorm2d(mid_channels)
        self.relu4 = nn.ReLU()
        # Exit: 1x1 expansion back to `channels`.
        self.conv5 = nn.Conv2d(mid_channels, channels, 1, 1, bias=bias)
        self.bn5 = nn.BatchNorm2d(channels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv0(x)
        y = self.bn0(y)
        # First inner residual: y + f(y).
        z = self.conv1(y)
        z = self.bn1(z)
        z = self.relu1(z)
        z = self.conv2(z)
        z = self.bn2(z)
        y = y + z
        y = self.relu2(y)
        # Second inner residual.
        z = self.conv3(y)
        z = self.bn3(z)
        z = self.relu3(z)
        z = self.conv4(z)
        z = self.bn4(z)
        y = y + z
        y = self.relu4(y)
        # Expand and apply the outer shortcut.
        y = self.conv5(y)
        y = self.bn5(y)
        x = x + y
        x = self.relu5(x)
        return x

class ResNet(nn.Module):
    """Convolutional stem followed by a stack of residual blocks."""

    def __init__(self, block, in_channels, layers, channels, bias):
        super().__init__()
        # Stem: 5x5 convolution lifts the input planes to the trunk width.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, channels, kernel_size=5, stride=1, padding=2, bias=bias
            ),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )
        self.convs = nn.ModuleList(
            [block(channels, channels, bias) for _ in range(layers)]
        )

    def forward(self, x):
        x = self.conv1(x)
        for conv in self.convs:
            x = conv(x)
        return x

class AlphaZero(nn.Module):
    """AlphaZero-style network: a residual trunk with policy and value heads.

    `board_size` is the number of board cells (height * width), since the
    head features are flattened over the spatial dimensions.
    """

    def __init__(
        self,
        in_channels,
        layers,
        channels,
        moves,
        board_size,
        value_heads=1,
        bias=False,
        block=BasicBlock,
    ):
        super().__init__()
        self.board_size = board_size
        self.resnet = ResNet(block, in_channels, layers, channels, bias)

        # Policy head: 1x1 convolution to 2 planes, then a linear layer over
        # the flattened planes producing one logit per move.
        self.policy_head_front = nn.Sequential(
            nn.Conv2d(channels, 2, 1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
        )
        self.policy_head_end = nn.Linear(2 * board_size, moves)

        # Value head: 1x1 convolution to a single plane, then an MLP ending
        # in tanh, so each value lies in [-1, 1].
        self.value_head_front = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        self.value_head_end = nn.Sequential(
            nn.Linear(board_size, channels),
            nn.ReLU(),
            nn.Linear(channels, value_heads),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.resnet(x)

        p = self.policy_head_front(x)
        p = p.view(-1, 2 * self.board_size)
        p = self.policy_head_end(p)

        v = self.value_head_front(x)
        v = v.view(-1, self.board_size)
        v = self.value_head_end(v)
        return p, v
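

# --- Example usage (illustrative sketch, not part of the original module) ---
# A minimal smoke test, assuming a 9x9 Go-like board: here `board_size` is
# the number of cells (81), `moves` is a hypothetical policy size
# (82 = 81 placements + pass), and the 3 input feature planes are arbitrary
# example values, not a claim about the intended encoding.
if __name__ == "__main__":
    import torch

    net = AlphaZero(in_channels=3, layers=4, channels=32, moves=82, board_size=81)
    x = torch.randn(2, 3, 9, 9)  # batch of 2 positions, 3 planes, 9x9 board
    policy, value = net(x)
    print(policy.shape)  # torch.Size([2, 82]) -- one logit per move
    print(value.shape)   # torch.Size([2, 1])  -- one tanh value per position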