File size: 2,187 Bytes
c97d40f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import hashlib
import torch
from loguru import logger
class MultiInputConv(torch.nn.Module):
    """Score a board position from a board tensor plus two side features.

    The board tensor runs through three parallel two-layer convolutional
    branches whose first layers use kernel sizes 15, 9 and 5 (long, middle
    and short receptive fields). Each branch's output is flattened, the three
    results are concatenated with the ``color`` and ``castling`` features, and
    a small MLP maps that vector to a single score per sample.
    """

    # Input channels expected on the board tensor (first conv's in_channels).
    _BOARD_CHANNELS = 12
    # Channels emitted by the last conv layer of every branch.
    _BRANCH_OUT_CHANNELS = 4
    # Spatial size per dimension after each branch's two stride-2 convs.
    # With "same"-style padding this is 2 for input dimensions of 5..8
    # (presumably an 8x8 board -- TODO confirm); other sizes need updating.
    _BRANCH_OUT_SIDE = 2
    # Widths of the extra features concatenated after the conv branches,
    # in the order they are concatenated in ``forward``.
    _COLOR_FEATURES = 1
    _CASTLING_FEATURES = 4

    @logger.catch(reraise=True)
    def __init__(self):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        # Keep this construction order: default parameter init draws from
        # torch's global RNG, so reordering would change the initial weights
        # obtained for a given seed.
        self.conv_long = self._conv_branch(kernel_size=15)
        self.conv_middle = self._conv_branch(kernel_size=9)
        self.conv_short = self._conv_branch(kernel_size=5)

        branch_features = self._BRANCH_OUT_CHANNELS * self._BRANCH_OUT_SIDE ** 2
        # 3 conv branches + color + castling features (= 53 for the defaults).
        in_features = (
            3 * branch_features + self._COLOR_FEATURES + self._CASTLING_FEATURES
        )
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=16),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=16, out_features=1),
        )

    @classmethod
    def _conv_branch(cls, kernel_size: int) -> torch.nn.Sequential:
        """Build one branch: a kernel_size conv then a 7x7 conv, both stride 2.

        ``padding=kernel_size // 2`` gives "same"-style padding for the odd
        kernel sizes used here (15 -> 7, 9 -> 4, 5 -> 2). Layer indices inside
        the Sequential (0..3) determine the state_dict keys, so keep them.
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=cls._BOARD_CHANNELS,
                out_channels=16,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                stride=2,
            ),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(
                in_channels=16,
                out_channels=cls._BRANCH_OUT_CHANNELS,
                kernel_size=7,
                padding=3,
                stride=2,
            ),
            torch.nn.LeakyReLU(),
        )

    @logger.catch(reraise=True)
    def forward(self, x):
        """Compute one score per sample.

        Args:
            x: A 3-item sequence ``(board, color, castling)`` of tensors.
                ``board`` feeds Conv2d layers with 12 input channels
                (presumably shape (batch, 12, 8, 8) -- TODO confirm).
                ``color`` and ``castling`` are concatenated along dim 1 after
                the flattened branch outputs, so they must be 2-D; per the
                head's input width they are expected to be (batch, 1) and
                (batch, 4). All three inputs are cast to float32 first.

        Returns:
            Tensor of shape (batch, 1) produced by the final Linear(16, 1).

        Raises:
            Any exception from the layers is logged by loguru and re-raised
            (rather than swallowed and turned into a ``None`` return).
        """
        board, color, castling = x
        board = board.float()
        features = [
            self.flatten(branch(board))
            for branch in (self.conv_long, self.conv_middle, self.conv_short)
        ]
        features.append(color.float())
        features.append(castling.float())
        return self.linear_relu_stack(torch.cat(features, dim=1))

    @logger.catch(reraise=True)
    def model_hash(self) -> str:
        """Return an MD5 hex digest identifying the model's architecture.

        The digest is taken over ``str(self)``, the repr of the whole module
        tree. Changing any conv branch, the flatten layer or the linear head
        therefore changes the hash. Parameter values are not part of a module
        repr, so the weights do not affect the hash. MD5 is used only as an
        identifier here, not for security.

        NOTE: these digests differ from ones computed over only the linear
        head and flatten layer; any stored hashes must be regenerated.
        """
        return hashlib.md5(str(self).encode()).hexdigest()
|