import hashlib import torch from loguru import logger class MultiInputConv(torch.nn.Module): @logger.catch def __init__(self): super().__init__() self.flatten = torch.nn.Flatten() self.conv_long = torch.nn.Sequential( torch.nn.Conv2d(in_channels=12, out_channels=16, kernel_size=15, padding=7, stride=2), torch.nn.LeakyReLU(), torch.nn.Conv2d(in_channels=16, out_channels=4, kernel_size=7, padding=3, stride=2), torch.nn.LeakyReLU(), ) self.conv_middle = torch.nn.Sequential( torch.nn.Conv2d(in_channels=12, out_channels=16, kernel_size=9, padding=4, stride=2), torch.nn.LeakyReLU(), torch.nn.Conv2d(in_channels=16, out_channels=4, kernel_size=7, padding=3, stride=2), torch.nn.LeakyReLU(), ) self.conv_short = torch.nn.Sequential( torch.nn.Conv2d(in_channels=12, out_channels=16, kernel_size=5, padding=2, stride=2), torch.nn.LeakyReLU(), torch.nn.Conv2d(in_channels=16, out_channels=4, kernel_size=7, padding=3, stride=2), torch.nn.LeakyReLU(), ) self.linear_relu_stack = torch.nn.Sequential( torch.nn.Linear(in_features=(4 * 2 * 2) + (4 * 2 * 2) + (4 * 2 * 2) + 1 + 4, out_features=16), torch.nn.LeakyReLU(), torch.nn.Linear(in_features=16, out_features=1), ) @logger.catch def forward(self, x): board, color, castling = x board = board.float() color = color.float() castling = castling.float() long = self.conv_long(board) long = self.flatten(long) middle = self.conv_middle(board) middle = self.flatten(middle) short = self.conv_short(board) short = self.flatten(short) x = torch.cat((long, middle, short, color, castling), dim=1) score = self.linear_relu_stack(x) return score @logger.catch def model_hash(self) -> str: """Get the hash of the model.""" return hashlib.md5( (str(self.linear_relu_stack) + str(self.flatten)).encode() ).hexdigest()