import torch
from torch import nn
import torch.nn.functional as F

from utils import NestedTensor, nested_tensor_from_tensor_list
from backbone import build_backbone
from transformer import build_transformer


class Caption(nn.Module):
    """Image captioning model: CNN backbone + transformer + vocabulary head."""

    def __init__(self, backbone, transformer, hidden_dim, vocab_size):
        super().__init__()
        self.backbone = backbone
        # 1x1 conv projecting the backbone's channel count down to the
        # transformer's hidden size.
        self.input_proj = nn.Conv2d(
            backbone.num_channels, hidden_dim, kernel_size=1)
        self.transformer = transformer
        # Maps decoder hidden states to logits over the vocabulary.
        self.mlp = MLP(hidden_dim, 512, vocab_size, 3)

    def forward(self, samples, target, target_mask):
        # Wrap raw image tensors into a NestedTensor (tensor + padding mask).
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)

        features, pos = self.backbone(samples)
        # Use the last feature map and its padding mask.
        src, mask = features[-1].decompose()
        assert mask is not None

        hs = self.transformer(self.input_proj(src), mask,
                              pos[-1], target, target_mask)
        # Permute decoder output to (batch, seq_len, hidden_dim) before the head.
        out = self.mlp(hs.permute(1, 0, 2))
        return out


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            # ReLU on every layer except the final (output) one.
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build_model(config):
    backbone = build_backbone(config)
    transformer = build_transformer(config)
    model = Caption(backbone, transformer,
                    config.hidden_dim, config.vocab_size)
    criterion = torch.nn.CrossEntropyLoss()
    return model, criterion
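

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal check of the MLP head in isolation; the dimensions and batch
# shapes below are made-up placeholders, not values from this repo. Running
# the full Caption model additionally requires a config object carrying
# whatever fields build_backbone/build_transformer expect, which are defined
# elsewhere in the codebase.
if __name__ == "__main__":
    head = MLP(input_dim=256, hidden_dim=512, output_dim=30522, num_layers=3)
    # Decoder states of shape (batch, seq_len, hidden_dim) map to
    # per-token vocabulary logits of shape (batch, seq_len, vocab_size).
    dummy_hs = torch.randn(2, 128, 256)
    logits = head(dummy_hs)
    assert logits.shape == (2, 128, 30522)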