import sys
from collections import OrderedDict

import torch
from torch import nn

from .transformer import ViT

# Make the package's parent directory importable so the shared hyperparameter
# modules (params_model, params_data) resolve.
sys.path.append("/".join(__file__.split('/')[:-2]))

from params_model import *
from params_data import *

class ConvBlock(nn.Sequential):
    """Conv2d -> BatchNorm -> ReLU, kept as a named Sequential."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__(OrderedDict([
            ('conv', nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False)),
            ('bn', nn.BatchNorm2d(out_channels)),
            ('relu', nn.ReLU(inplace=True)),
        ]))

class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation block that predicts a per-channel scale and shift."""

    def __init__(self, channels, ratio):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)

        self.lin1 = nn.Linear(channels, channels // ratio)
        self.relu = nn.ReLU(inplace=True)
        # Twice the channel count: one half becomes the scale, the other the shift.
        self.lin2 = nn.Linear(channels // ratio, 2 * channels)

    def forward(self, x):
        n, c, h, w = x.size()
        x_in = x

        # Squeeze: global average pool to a per-channel descriptor.
        x = self.pool(x).view(n, c)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)

        # Excite: apply the predicted scale (gated by a sigmoid) and shift.
        x = x.view(n, 2 * c, 1, 1)
        scale, shift = x.chunk(2, dim=1)

        x = scale.sigmoid() * x_in + shift
        return x

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a squeeze-excitation block and a skip connection."""

    def __init__(self, channels, se_ratio):
        super().__init__()
        self.layers = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(channels)),
            ('relu', nn.ReLU(inplace=True)),

            ('conv2', nn.Conv2d(channels, channels, 3, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(channels)),

            ('se', SqueezeExcitation(channels, se_ratio)),
        ]))
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x_in = x

        x = self.layers(x)

        # Residual connection followed by the final activation.
        x = x + x_in
        x = self.relu2(x)
        return x

class Encoder(nn.Module):
    """Per-frame CNN feature extractor followed by a ViT that aggregates the
    frames of a game into a single L2-normalised embedding."""

    def __init__(self, loss_device, loss_method="softmax"):
        super().__init__()
        self.loss_device = loss_device

        channels = residual_channels

        # Per-frame CNN: input stem over the 34 input planes, SE-residual tower,
        # and a projection down to the channel count expected by the transformer.
        self.conv_block = ConvBlock(34, channels, 3, padding=1)
        blocks = [(f'block{i+1}', ResidualBlock(channels, se_ratio)) for i in range(residual_blocks)]
        self.residual_stack = nn.Sequential(OrderedDict(blocks))

        self.conv_block2 = ConvBlock(channels, channels, 3, padding=1)
        self.final_feature = ConvBlock(channels, vit_input_channels, 3, padding=1)
        # Pools over an 8x8 window; acts as global average pooling for 8x8 feature maps.
        self.global_avgpool = nn.AvgPool2d(kernel_size=8)

        self.cnn = nn.Sequential(
            self.conv_block,
            self.residual_stack,
            self.conv_block2,
            self.final_feature,
            self.global_avgpool,
            nn.Flatten(),
        )

        # Transformer that attends over the per-frame feature vectors.
        self.transformer = ViT(input_dim=vit_input_channels,
                               output_dim=model_embedding_size,
                               dim=transformer_input_dim,
                               depth=transformer_depth,
                               heads=attention_heads,
                               mlp_dim=mlp_dim,
                               pool='mean',
                               dim_head=dim_head,
                               dropout=dropout,
                               emb_dropout=emb_dropout)

        # Learnable scaling applied to the similarity matrix by the loss.
        self.similarity_weight = nn.Parameter(torch.tensor([similarity_weight_init]))
        self.similarity_bias = nn.Parameter(torch.tensor([similarity_bias_init]))

    def forward(self, games):
        # games: (batch_size, n_frames, *feature_shape)
        batch_size, n_frames, feature_shape = games.shape[0], games.shape[1], games.shape[2:]

        # Fold the frame dimension into the batch so every frame passes through the CNN.
        games = torch.reshape(games, (batch_size * n_frames, *feature_shape))

        game_features = self.cnn(games)

        # Restore the (batch, frames, features) layout for the transformer.
        game_features = torch.reshape(game_features, (batch_size, n_frames, game_features.shape[-1]))

        embeds_raw = self.transformer(game_features)

        # L2-normalise the per-game embeddings.
        embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

        return embeds
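

# Minimal smoke-test sketch, not part of the original module. It assumes the
# hyperparameters imported from params_model / params_data are defined and that
# each frame is a 34 x 8 x 8 feature tensor (34 planes from the input ConvBlock,
# 8x8 spatial size implied by the kernel_size=8 average pool). Run it as
# `python -m <package>.model` so the relative `.transformer` import resolves.
if __name__ == "__main__":
    encoder = Encoder(loss_device=torch.device("cpu"))
    games = torch.randn(2, 5, 34, 8, 8)  # (batch, frames, planes, height, width)
    embeddings = encoder(games)
    print(embeddings.shape)  # expected: (2, model_embedding_size)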