"""Image captioning model: a CNN backbone plus a transformer that maps
image features and partial captions to vocabulary logits."""
import torch
from torch import nn
import torch.nn.functional as F

from utils import NestedTensor, nested_tensor_from_tensor_list
from backbone import build_backbone
from transformer import build_transformer


class Caption(nn.Module):
    def __init__(self, backbone, transformer, hidden_dim, vocab_size):
        super().__init__()
        self.backbone = backbone
        # 1x1 conv projecting the backbone's channels to the transformer width.
        self.input_proj = nn.Conv2d(
            backbone.num_channels, hidden_dim, kernel_size=1)
        self.transformer = transformer
        # Prediction head mapping each decoder state to vocabulary logits.
        self.mlp = MLP(hidden_dim, 512, vocab_size, 3)

    def forward(self, samples, target, target_mask):
        # Accept raw image tensors by wrapping them in a NestedTensor
        # (batched tensor plus its padding mask).
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)

        features, pos = self.backbone(samples)
        # Use the last backbone feature map and its padding mask.
        src, mask = features[-1].decompose()

        assert mask is not None

        hs = self.transformer(self.input_proj(src), mask,
                              pos[-1], target, target_mask)
        # hs is (seq_len, batch, hidden_dim); permute to batch-first
        # before projecting to vocabulary logits.
        out = self.mlp(hs.permute(1, 0, 2))
        return out

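# Shape sketch for Caption.forward (an assumption based on the permute above
# and typical DETR-style backbones; the exact shapes come from build_backbone
# and build_transformer):
#   samples:     batched images, wrapped into a NestedTensor with padding mask
#   target:      (batch, max_len) caption token ids
#   target_mask: (batch, max_len) padding mask for the caption
#   out:         (batch, max_len, vocab_size) vocabulary logits
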


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # num_layers linear layers: input_dim -> hidden_dim (repeated) -> output_dim.
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k)
                                    for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # ReLU between layers; the final layer outputs raw logits.
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

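# For example, MLP(input_dim=256, hidden_dim=512, output_dim=vocab_size,
# num_layers=3) (256 is illustrative) builds
# Linear(256, 512) -> Linear(512, 512) -> Linear(512, vocab_size),
# applying ReLU after every layer except the last. The Caption head above
# follows the same pattern: hidden_dim -> 512 -> 512 -> vocab_size.
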


def build_model(config):
    backbone = build_backbone(config)
    transformer = build_transformer(config)

    model = Caption(backbone, transformer, config.hidden_dim, config.vocab_size)
    # Token-level cross-entropy over the vocabulary.
    criterion = torch.nn.CrossEntropyLoss()

    return model, criterion
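

if __name__ == "__main__":
    # Illustrative sketch only: how the (model, criterion) pair returned by
    # build_model is typically consumed in a training step. The batch size,
    # sequence length, and vocab size below are assumptions for the demo, and
    # a random tensor stands in for the model's output logits.
    batch, max_len, vocab_size = 2, 20, 1000
    logits = torch.rand(batch, max_len, vocab_size)  # stand-in for Caption output
    targets = torch.randint(0, vocab_size, (batch, max_len))
    criterion = torch.nn.CrossEntropyLoss()
    # CrossEntropyLoss expects class logits on dim 1, so move the vocab axis.
    loss = criterion(logits.permute(0, 2, 1), targets)
    print(f"demo loss: {loss.item():.4f}")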